1use std::{f32::consts::SQRT_2, fs::File, num::NonZeroUsize, time::Duration};
4
5use object_pool::Pool;
6use rubato::{FastFixedIn, Resampler, ResamplerConstructionError};
7use symphonia::{
8 core::{
9 audio::{AudioBufferRef, SampleBuffer, SignalSpec},
10 codecs::{CODEC_TYPE_NULL, DecoderOptions},
11 errors::Error,
12 formats::{FormatOptions, FormatReader},
13 io::{MediaSourceStream, MediaSourceStreamOptions},
14 meta::MetadataOptions,
15 probe::Hint,
16 units,
17 },
18 default::get_probe,
19};
20
21use crate::{ResampledAudio, SAMPLE_RATE, errors::AnalysisError, errors::AnalysisResult};
22
23use super::Decoder;
24
/// Number of recoverable decode errors tolerated while searching for the next decodable
/// packet; one more error after this many is reported to the caller.
const MAX_DECODE_RETRIES: usize = 3;
/// Number of input frames handed to the resampler per processing call.
const CHUNK_SIZE: usize = 1 << 12;
27
28#[doc(hidden)]
30pub struct SymphoniaSource {
31 decoder: Box<dyn symphonia::core::codecs::Decoder>,
32 current_span_offset: usize,
33 format: Box<dyn FormatReader>,
34 total_duration: Option<Duration>,
35 buffer: SampleBuffer<f32>,
36 spec: SignalSpec,
37}
38
39impl SymphoniaSource {
40 pub fn new(mss: MediaSourceStream) -> Result<Self, Error> {
47 Self::init(mss)?.ok_or(Error::DecodeError("No Streams"))
48 }
49
50 fn init(mss: MediaSourceStream) -> symphonia::core::errors::Result<Option<Self>> {
51 let hint = Hint::new();
52 let format_opts = FormatOptions {
53 enable_gapless: true,
54 ..Default::default()
55 };
56 let metadata_opts = MetadataOptions::default();
57 let mut probed_format = get_probe()
58 .format(&hint, mss, &format_opts, &metadata_opts)?
59 .format;
60
61 let Some(stream) = probed_format.default_track() else {
62 return Ok(None);
63 };
64
65 let track = probed_format
67 .tracks()
68 .iter()
69 .find(|t| t.codec_params.codec != CODEC_TYPE_NULL)
70 .ok_or(Error::Unsupported("No track with supported codec"))?;
71
72 let track_id = track.id;
73
74 let mut decoder = symphonia::default::get_codecs()
75 .make(&track.codec_params, &DecoderOptions::default())?;
76 let total_duration = stream
77 .codec_params
78 .time_base
79 .zip(stream.codec_params.n_frames)
80 .map(|(base, spans)| base.calc_time(spans).into());
81
82 let mut decode_errors: usize = 0;
83 let decoded_audio = loop {
84 let current_span = probed_format.next_packet()?;
85
86 if current_span.track_id() != track_id {
88 continue;
89 }
90
91 match decoder.decode(¤t_span) {
92 Ok(audio) => break audio,
93 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
94 decode_errors += 1;
95 }
96 Err(e) => return Err(e),
97 }
98 };
99
100 let spec = decoded_audio.spec().to_owned();
101 let buffer = Self::get_buffer(decoded_audio, spec);
102 Ok(Some(Self {
103 decoder,
104 current_span_offset: 0,
105 format: probed_format,
106 total_duration,
107 buffer,
108 spec,
109 }))
110 }
111
112 #[inline]
113 fn get_buffer(decoded: AudioBufferRef<'_>, spec: SignalSpec) -> SampleBuffer<f32> {
114 let duration = units::Duration::from(decoded.capacity() as u64);
115 let mut buffer = SampleBuffer::<f32>::new(duration, spec);
116 buffer.copy_interleaved_ref(decoded);
117 buffer
118 }
119
120 #[inline]
121 #[must_use]
122 pub const fn total_duration(&self) -> Option<Duration> {
123 self.total_duration
124 }
125
126 #[inline]
127 #[must_use]
128 pub const fn sample_rate(&self) -> u32 {
129 self.spec.rate
130 }
131
132 #[inline]
133 #[must_use]
134 pub fn channels(&self) -> usize {
135 self.spec.channels.count()
136 }
137}
138
139impl Iterator for SymphoniaSource {
140 type Item = f32;
141
142 fn size_hint(&self) -> (usize, Option<usize>) {
143 (
144 self.buffer.samples().len(),
145 self.total_duration.map(|dur| {
146 usize::try_from(
147 (dur.as_secs() + 1)
148 * u64::from(self.spec.rate)
149 * self.spec.channels.count() as u64,
150 )
151 .unwrap_or(usize::MAX)
152 }),
153 )
154 }
155
156 fn next(&mut self) -> Option<Self::Item> {
157 if self.current_span_offset < self.buffer.len() {
158 let sample = self.buffer.samples().get(self.current_span_offset);
159 self.current_span_offset += 1;
160
161 return sample.copied();
162 }
163
164 let mut decode_errors = 0;
165 let decoded = loop {
166 let packet = self.format.next_packet().ok()?;
167 match self.decoder.decode(&packet) {
168 Ok(decoded) if decoded.frames() > 0 => break decoded,
174 Ok(_) => {}
175 Err(Error::DecodeError(_)) if decode_errors < MAX_DECODE_RETRIES => {
176 decode_errors += 1;
177 }
178 Err(_) => return None,
179 }
180 };
181
182 decoded.spec().clone_into(&mut self.spec);
183 self.buffer = Self::get_buffer(decoded, self.spec);
184 self.current_span_offset = 1;
185 self.buffer.samples().first().copied()
186 }
187}
188
189#[allow(clippy::module_name_repetitions)]
190pub struct MecompDecoder<R = FastFixedIn<f32>> {
191 resampler: Pool<Result<R, ResamplerConstructionError>>,
192}
193
194impl MecompDecoder {
195 #[inline]
196 fn generate_resampler() -> Result<FastFixedIn<f32>, ResamplerConstructionError> {
197 FastFixedIn::new(1.0, 10.0, rubato::PolynomialDegree::Cubic, CHUNK_SIZE, 1)
198 }
199
200 #[inline]
206 pub fn new() -> Result<Self, AnalysisError> {
207 let first = Self::generate_resampler()?;
209
210 let pool_size = std::thread::available_parallelism().map_or(1, NonZeroUsize::get);
211 let resampler = Pool::new(pool_size, Self::generate_resampler);
212 resampler.attach(Ok(first));
213
214 Ok(Self { resampler })
215 }
216
217 #[inline]
227 #[doc(hidden)]
228 pub fn into_mono_samples(
229 source: Vec<f32>,
230 num_channels: usize,
231 ) -> Result<Vec<f32>, AnalysisError> {
232 match num_channels {
233 0 => Err(AnalysisError::DecodeError(Error::DecodeError(
235 "The audio source has no channels",
236 ))),
237 1 => Ok(source),
239 2 => {
241 let len = source.len() / 2;
242 let mut result = vec![0f32; len];
243 let scale = SQRT_2 * 0.5;
244
245 let (src_chunks, src_remainder) = source.as_chunks::<16>();
247 let (dest_chunks, dest_remainder) = result.as_chunks_mut::<8>();
248
249 for (src, dest) in src_chunks.iter().zip(dest_chunks) {
250 for i in 0..8 {
252 dest[i] = (src[2 * i] + src[2 * i + 1]) * scale;
253 }
254 }
255
256 for (i, chunk) in src_remainder.chunks_exact(2).enumerate() {
258 dest_remainder[i] = (chunk[0] + chunk[1]) * scale;
259 }
260
261 Ok(result)
262 }
263 _ => {
265 log::warn!(
266 "The audio source has more than 2 channels (might be 2.1 or 5.1 surround sound), will collapse to mono by averaging the channels"
267 );
268
269 #[allow(clippy::cast_precision_loss)]
270 let num_channels_f32 = num_channels as f32;
271 let mono_samples = source
272 .chunks_exact(num_channels)
273 .map(|chunk| chunk.iter().sum::<f32>() / num_channels_f32)
274 .collect();
275
276 Ok(mono_samples)
277 }
278 }
279 }
280
281 #[inline]
283 #[doc(hidden)]
284 pub fn resample_mono_samples(
285 &self,
286 mut samples: Vec<f32>,
287 sample_rate: u32,
288 ) -> Result<Vec<f32>, AnalysisError> {
289 if sample_rate == SAMPLE_RATE {
290 samples.shrink_to_fit();
291 return Ok(samples);
292 }
293
294 let resample_ratio = f64::from(SAMPLE_RATE) / f64::from(sample_rate);
295 #[allow(
296 clippy::cast_possible_truncation,
297 clippy::cast_sign_loss,
298 clippy::cast_precision_loss
299 )]
300 let mut resampled_frames = Vec::with_capacity(
301 (samples.len() as f64 * resample_ratio) as usize + SAMPLE_RATE as usize, );
303
304 let (pool, resampler) = self.resampler.pull(Self::generate_resampler).detach();
305 let mut resampler = resampler?;
306 resampler.set_resample_ratio(resample_ratio, false)?;
307
308 let delay = resampler.output_delay();
309
310 let new_length = samples.len() * SAMPLE_RATE as usize / sample_rate as usize;
311 let mut output_buffer = resampler.output_buffer_allocate(true);
312
313 let sample_chunks = samples.chunks_exact(CHUNK_SIZE);
315 let remainder = sample_chunks.remainder();
316
317 for chunk in sample_chunks {
318 debug_assert!(resampler.input_frames_next() == CHUNK_SIZE);
319
320 let (_, output_written) =
321 resampler.process_into_buffer(&[chunk], output_buffer.as_mut_slice(), None)?;
322 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
323 }
324
325 if !remainder.is_empty() {
327 let (_, output_written) = resampler.process_partial_into_buffer(
328 Some(&[remainder]),
329 output_buffer.as_mut_slice(),
330 None,
331 )?;
332 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
333 }
334
335 if resampled_frames.len() < new_length + delay {
337 let (_, output_written) = resampler.process_partial_into_buffer(
338 Option::<&[&[f32]]>::None,
339 output_buffer.as_mut_slice(),
340 None,
341 )?;
342 resampled_frames.extend_from_slice(&output_buffer[0][..output_written]);
343 }
344
345 resampler.reset();
346 pool.attach(Ok(resampler));
347
348 Ok(resampled_frames[delay..new_length + delay].to_vec())
349 }
350}
351
352impl Decoder for MecompDecoder {
353 #[allow(clippy::missing_inline_in_public_items)]
359 fn decode(&self, path: &std::path::Path) -> AnalysisResult<ResampledAudio> {
360 let file = File::open(path)?;
362 let mss = MediaSourceStream::new(Box::new(file), MediaSourceStreamOptions::default());
364
365 let source = SymphoniaSource::new(mss)?;
366
367 let sample_rate = source.spec.rate;
369 let num_channels = source.channels();
370
371 let mono_sample_array =
372 Self::into_mono_samples(source.into_iter().collect(), num_channels)?;
373
374 let resampled_array = self.resample_mono_samples(mono_sample_array, sample_rate)?;
376
377 Ok(ResampledAudio {
378 path: path.to_owned(),
379 samples: resampled_array,
380 })
381 }
382}
383
#[cfg(test)]
mod tests {
    use crate::{NUMBER_FEATURES, embeddings::ModelConfig};

    use super::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use adler32::RollingAdler32;
    use pretty_assertions::assert_eq;
    use rstest::rstest;
    use std::{
        collections::HashMap,
        path::{Path, PathBuf},
        sync::mpsc,
    };

    /// Decodes `path` and compares the Adler-32 checksum of the little-endian bytes of every
    /// output sample against `expected_hash`.
    fn verify_decoding_output(path: &Path, expected_hash: u32) {
        let decoder = Decoder::new().unwrap();
        let song = decoder.decode(path).unwrap();
        let mut hasher = RollingAdler32::new();
        for sample in &song.samples {
            hasher.update_buffer(&sample.to_le_bytes());
        }

        assert_eq!(expected_hash, hasher.hash());
    }

    /// Lists the `.wav`, `.flac` and `.mp3` files directly inside the crate's `data`
    /// directory. Files without an extension are skipped instead of panicking.
    fn audio_fixture_paths() -> Vec<PathBuf> {
        Path::new(env!("CARGO_MANIFEST_DIR"))
            .join("data")
            .read_dir()
            .unwrap()
            .map(|entry| entry.unwrap().path())
            .filter(|p| {
                p.is_file()
                    && p
                        .extension()
                        .is_some_and(|ext| ext == "wav" || ext == "flac" || ext == "mp3")
            })
            .collect()
    }

    #[rstest]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_stereo(Path::new("data/s32_stereo_44_1_kHz.flac"), 0xbbcb_a1cf)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mono(Path::new("data/s32_mono_44_1_kHz.flac"), 0xa0f8_b8af)]
    #[case::decode_stereo(Path::new("data/s16_stereo_22_5kHz.flac"), 0x1d7b_2d6d)]
    #[case::decode_mono(Path::new("data/s16_mono_22_5kHz.flac"), 0x5e01_930b)]
    #[ignore = "fails when asked to resample to 22050 Hz, ig ffmpeg does it differently, but I'm not sure what the difference actually is"]
    #[case::resample_mp3(Path::new("data/s32_stereo_44_1_kHz.mp3"), 0x69ca_6906)]
    #[case::decode_wav(Path::new("data/piano.wav"), 0xde83_1e82)]
    fn test_decode(#[case] path: &Path, #[case] expected_hash: u32) {
        verify_decoding_output(path, expected_hash);
    }

    #[test]
    fn test_dont_panic_no_channel_layout() {
        let path = Path::new("data/no_channel.wav");
        Decoder::new().unwrap().decode(path).unwrap();
    }

    #[test]
    fn test_decode_right_capacity_vec() {
        let decoder = Decoder::new().unwrap();
        for path in [
            "data/s16_mono_22_5kHz.flac",
            "data/s32_stereo_44_1_kHz.flac",
            "data/capacity_fix.wav",
        ] {
            let samples = decoder.decode(Path::new(path)).unwrap().samples;
            assert_eq!(
                samples.len(),
                samples.capacity(),
                "sample vector for {path} is not tightly sized"
            );
        }
    }

    const PATH_AND_EXPECTED_ANALYSIS: (&str, [f32; NUMBER_FEATURES]) = (
        "data/s16_mono_22_5kHz.flac",
        [
            0.384_638_9,
            -0.849_141_,
            -0.754_810_45,
            -0.879_074_8,
            -0.632_582_66,
            -0.725_895_9,
            -0.775_737_9,
            -0.814_672_6,
            0.271_672_6,
            0.257_790_57,
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
            -0.060_007_63,
            -0.584_857_17,
            -0.078_803_78,
        ],
    );

    #[test]
    fn test_analyze() {
        let (path, expected_analysis) = PATH_AND_EXPECTED_ANALYSIS;
        let analysis = Decoder::new()
            .unwrap()
            .analyze_path(Path::new(path))
            .unwrap();
        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                1e-5 > (x - y).abs(),
                "Expected {x} to be within 1e-5 of {y}, but it was not"
            );
        }
    }

    const RESAMPLED_PATH_AND_EXPECTED_ANALYSIS: (&str, [f32; NUMBER_FEATURES]) = (
        "data/s32_stereo_44_1_kHz.flac",
        [
            0.38463664,
            -0.85172224,
            -0.7607465,
            -0.8857495,
            -0.63906085,
            -0.73908424,
            -0.7890965,
            -0.8191868,
            0.33856833,
            0.3246863,
            -0.34292227,
            -0.62803173,
            -0.2809453,
            0.08687115,
            0.2444489,
            -0.5723239,
            0.23292565,
            0.19979525,
            -0.58593845,
            -0.06783122,
            -0.060014784,
            -0.5848569,
            -0.07879859,
        ],
    );

    #[test]
    fn test_analyze_resampled() {
        let (path, expected_analysis) = RESAMPLED_PATH_AND_EXPECTED_ANALYSIS;
        let analysis = Decoder::new()
            .unwrap()
            .analyze_path(Path::new(path))
            .unwrap();

        for (x, y) in analysis.as_vec().iter().zip(expected_analysis) {
            assert!(
                0.1 > (x - y).abs(),
                "Expected {x} to be within 0.1 of {y}, but it was not"
            );
        }
    }

    #[test]
    fn test_analyze_paths() {
        let paths = audio_fixture_paths();
        let mut analyzed_paths: HashMap<_, _> =
            paths.iter().map(|path| (path.clone(), false)).collect();

        let decoder = Decoder::new().unwrap();
        let expected = paths.len();
        let (tx, rx) = mpsc::channel();
        let handle = std::thread::spawn(move || decoder.analyze_paths(&paths, tx));

        let mut count = 0;
        for (path, analysis) in rx {
            count += 1;
            assert!(analysis.is_ok(), "Failed to analyze {path:?}: {analysis:?}",);
            assert_eq!(
                analyzed_paths.insert(path.clone(), true),
                Some(false),
                "Analyzed the same path twice: {path:?}"
            );
        }

        assert_eq!(count, expected);
        assert!(
            analyzed_paths.values().all(|&v| v),
            "Not all paths were analyzed: {analyzed_paths:?}"
        );

        handle.join().unwrap().unwrap();
    }

    #[test]
    fn test_process_paths() {
        let paths = audio_fixture_paths();
        let mut analyzed_paths: HashMap<_, _> =
            paths.iter().map(|path| (path.clone(), false)).collect();

        let decoder = Decoder::new().unwrap();
        let model_config = ModelConfig::default();
        let expected = paths.len();
        let (tx, rx) = mpsc::sync_channel(4);

        std::thread::spawn(move || decoder.process_songs(&paths, tx, model_config));

        let mut count = 0;
        for (path, analysis, embedding) in rx {
            count += 1;
            assert!(analysis.is_ok(), "Failed to analyze {path:?}: {analysis:?}");
            assert!(embedding.is_ok(), "Failed to embed {path:?}: {embedding:?}");
            assert_eq!(
                analyzed_paths.insert(path.clone(), true),
                Some(false),
                "Analyzed the same path twice: {path:?}"
            );
        }
        assert_eq!(count, expected);
        assert!(
            analyzed_paths.values().all(|&v| v),
            "Not all paths were analyzed: {analyzed_paths:?}"
        );
    }
}