1use crate::Feature;
7
8use super::errors::{AnalysisError, AnalysisResult};
9use super::utils::{Normalize, hz_to_octs_inplace, stft};
10use bitvec::vec::BitVec;
11use likely_stable::{LikelyResult, likely, unlikely};
12use ndarray::{Array, Array1, Array2, Axis, Order, Zip, arr2, concatenate, s};
13use ndarray_stats::QuantileExt;
14use noisy_float::prelude::*;
15
/// Accumulator for the chroma (pitch-class) descriptor of a signal.
///
/// `do_` appends chromagram columns to `values_chroma`; `get_value` reduces
/// everything gathered so far to the final feature vector.
#[derive(Debug, Clone)]
#[allow(clippy::module_name_repetitions)]
pub struct ChromaDesc {
    /// Sample rate forwarded to the tuning estimation and to the chroma filter bank.
    sample_rate: u32,
    /// Number of chroma bins; `new` uses it as the row count of `values_chroma`.
    n_chroma: u32,
    /// Chromagram gathered so far: starts with zero columns, `do_` concatenates
    /// new columns along axis 1.
    values_chroma: Array2<f32>,
}
33
// Range constants required by the `Normalize` trait.
// `get_value` divides each feature group by its L2 norm before calling
// `normalize`, so the values it hands over lie between these two bounds.
// NOTE(review): how the trait maps [MIN_VALUE, MAX_VALUE] onto the output range
// is defined in `utils` and not visible here — confirm there.
impl Normalize for ChromaDesc {
    /// Upper bound of the raw values passed to `normalize`.
    const MAX_VALUE: Feature = 1.0;
    /// Lower bound of the raw values passed to `normalize`.
    const MIN_VALUE: Feature = 0.;
}
38
39impl ChromaDesc {
40 pub const WINDOW_SIZE: usize = 8192;
41 pub const MAX_L2_INTERVAL: f32 = 0.25;
49 pub const MAX_L2_TRIAD: f32 = 0.025;
57 pub const MAX_TRIAD_INTERVAL_RATIO: f32 = std::f32::consts::FRAC_PI_2;
59 #[must_use]
60 #[inline]
61 pub fn new(sample_rate: u32, n_chroma: u32) -> Self {
62 Self {
63 sample_rate,
64 n_chroma,
65 values_chroma: Array2::zeros((n_chroma as usize, 0)),
66 }
67 }
68
69 #[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
76 #[inline]
77 pub fn do_(&mut self, signal: &[f32]) -> AnalysisResult<()> {
78 let stft = stft(signal, Self::WINDOW_SIZE, 2205);
79 let tuning = estimate_tuning(self.sample_rate, &stft, Self::WINDOW_SIZE, 0.01, 12)?;
80 let chroma = chroma_stft(
81 self.sample_rate,
82 &stft,
83 Self::WINDOW_SIZE,
84 self.n_chroma,
85 tuning,
86 )?;
87 self.values_chroma = concatenate![Axis(1), self.values_chroma, chroma];
88 Ok(())
89 }
90
91 #[inline]
102 pub fn get_value(&mut self) -> Vec<Feature> {
103 let mut raw_features = chroma_interval_features(&self.values_chroma);
104 let (mut interval_class, mut interval_class_mode) =
105 raw_features.view_mut().split_at(Axis(0), 6);
106 let l2_norm_interval_class = interval_class.dot(&interval_class).sqrt();
109 let l2_norm_interval_class_mode = interval_class_mode.dot(&interval_class_mode).sqrt();
110 if l2_norm_interval_class > 0. {
111 interval_class /= l2_norm_interval_class;
112 }
113 if l2_norm_interval_class_mode > 0. {
114 interval_class_mode /= l2_norm_interval_class_mode;
115 }
116 let mut features = raw_features.mapv_into_any(|x| self.normalize(x)).to_vec();
117
118 let normalized_l2_norm_interval_class =
119 (2. * l2_norm_interval_class / Self::MAX_L2_INTERVAL - 1.).min(1.);
120 features.push(normalized_l2_norm_interval_class);
121 let normalized_l2_norm_interval_class_mode =
122 (2. * l2_norm_interval_class_mode / Self::MAX_L2_TRIAD - 1.).min(1.);
123 features.push(normalized_l2_norm_interval_class_mode);
124 let angle = (20. * l2_norm_interval_class_mode).atan2(l2_norm_interval_class + 1e-12_f32);
125 let normalized_ratio = 2. * angle / Self::MAX_TRIAD_INTERVAL_RATIO - 1.;
126 features.push(normalized_ratio);
127 features
128 }
129}
130
131#[allow(
134 clippy::missing_errors_doc,
135 clippy::missing_panics_doc,
136 clippy::module_name_repetitions
137)]
138#[must_use]
139#[inline]
140pub fn chroma_interval_features(chroma: &Array2<f32>) -> Array1<f32> {
141 let chroma = normalize_feature_sequence(&(chroma * 15.).exp());
142 let templates = arr2(&[
143 [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
144 [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
145 [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
146 [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
147 [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
148 [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
149 [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
150 [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
151 [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
152 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
153 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
154 [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
155 ]);
156 let interval_feature_matrix = extract_interval_features(&chroma, &templates);
157 interval_feature_matrix.mean_axis(Axis(1)).unwrap()
158}
159
160#[must_use]
161#[inline]
162pub fn extract_interval_features(chroma: &Array2<f32>, templates: &Array2<i32>) -> Array2<f32> {
163 let n_templates = templates.shape()[1];
164 let n_chroma = chroma.shape()[0]; let n_cols = chroma.shape()[1];
166
167 let chroma_t = chroma.t();
168 let mut f_intervals: Array2<f32> = Array::zeros((n_templates, n_cols));
169
170 let active_indices: Vec<Vec<usize>> = templates
172 .axis_iter(Axis(1))
173 .map(|t| {
174 t.iter()
175 .enumerate()
176 .filter_map(|(i, &v)| (v == 1).then_some(i))
177 .collect()
178 })
179 .collect();
180
181 for (col_idx, col) in chroma_t.rows().into_iter().enumerate() {
183 for (tmpl_idx, indices) in active_indices.iter().enumerate() {
184 let sum = (0..n_chroma)
185 .map(|shift| {
186 indices
187 .iter()
188 .map(|&idx| col[(idx + shift) % n_chroma])
189 .product::<f32>()
190 })
191 .sum();
192 f_intervals[(tmpl_idx, col_idx)] = sum;
193 }
194 }
195
196 f_intervals
197}
198
199#[inline]
200pub fn normalize_feature_sequence(feature: &Array2<f32>) -> Array2<f32> {
201 let mut normalized_sequence = feature.to_owned();
202 for mut column in normalized_sequence.columns_mut() {
203 let sum: f32 = column.iter().copied().map(f32::abs).sum();
204 if likely(sum >= 0.0001) {
205 column /= sum;
206 }
207 }
208
209 normalized_sequence
210}
211
212#[allow(
220 clippy::missing_errors_doc,
221 clippy::missing_panics_doc,
222 clippy::module_name_repetitions,
223 clippy::missing_inline_in_public_items
224)]
225pub fn chroma_filter(
226 sample_rate: u32,
227 n_fft: usize,
228 n_chroma: u32,
229 tuning: f32,
230) -> AnalysisResult<Array2<f32>> {
231 let ctroct = 5.0;
232 let octwidth = 2.;
233 #[allow(clippy::cast_precision_loss)]
234 let n_chroma2 = (n_chroma >> 1) as f32;
235 #[allow(clippy::cast_precision_loss)]
236 let n_chroma_float = n_chroma as f32;
237
238 #[allow(clippy::cast_precision_loss)]
239 let frequencies = Array::linspace(0., sample_rate as f32, n_fft + 1);
240
241 let mut freq_bins = frequencies;
242 hz_to_octs_inplace(&mut freq_bins, tuning, n_chroma);
243 freq_bins *= n_chroma_float;
244 freq_bins[0] = (1.5).mul_add(-n_chroma_float, freq_bins[1]);
245 let mut binwidth_bins = Array::ones(freq_bins.raw_dim());
246 binwidth_bins
247 .slice_mut(s![0..freq_bins.len() - 1])
248 .assign(&(&freq_bins.slice(s![1..]) - &freq_bins.slice(s![..-1])).mapv(|x| x.max(1.)));
249
250 let mut d: Array2<f32> = Array::zeros((n_chroma as usize, (freq_bins).len()));
251 for (idx, mut row) in d.rows_mut().into_iter().enumerate() {
252 #[allow(clippy::cast_precision_loss)]
253 row.fill(idx as f32);
254 }
255
256 d.zip_mut_with(&freq_bins, |d_elem, &fb| {
257 let x = -*d_elem + fb;
258 let x = n_chroma_float.mul_add(10., x + n_chroma2);
259 *d_elem = x % n_chroma_float - n_chroma2;
260 });
261 d.zip_mut_with(&binwidth_bins, |d_elem, &bb| {
262 let x = *d_elem / bb;
263 *d_elem = (-2. * x * x).exp();
264 });
265
266 let mut wts = d;
267 for mut col in wts.columns_mut() {
269 let sum = col.pow2().sum().sqrt();
270 if sum >= f32::MIN_POSITIVE {
271 col /= sum;
272 }
273 }
274
275 freq_bins = (-0.5 * ((freq_bins / n_chroma_float - ctroct) / octwidth).powi(2)).exp();
277
278 wts *= &freq_bins;
279
280 let mut b = Array2::zeros(wts.dim());
282 b.slice_mut(s![-3.., ..]).assign(&wts.slice(s![..3, ..]));
283 b.slice_mut(s![..-3, ..]).assign(&wts.slice(s![3.., ..]));
284
285 wts = b;
286 let non_aliased = 1 + n_fft / 2;
287 Ok(wts.slice_move(s![.., ..non_aliased]))
288}
289
290#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
291#[allow(clippy::missing_inline_in_public_items)]
292pub fn pip_track(
293 sample_rate: u32,
294 spectrum: &Array2<f32>,
295 n_fft: usize,
296) -> AnalysisResult<(Vec<f32>, Vec<f32>)> {
297 #[allow(clippy::cast_precision_loss)]
298 let sample_rate_float = sample_rate as f32;
299 let fmin = 150.0;
300 let fmax = 4000.0.min(sample_rate_float / 2.0);
301 let threshold = 0.1;
302
303 let fft_freqs = Array::linspace(0., sample_rate_float / 2., 1 + n_fft / 2);
304
305 let length = spectrum.shape()[1];
306
307 let freq_mask = fft_freqs
308 .iter()
309 .map(|&f| (fmin <= f) && (f < fmax))
310 .collect::<BitVec>();
311
312 let ref_value = spectrum.map_axis(Axis(0), |x| {
313 let first = *x.first().expect("empty spectrum axis");
314 let max = x.fold(first, |acc, &elem| acc.max(elem));
315 threshold * max
316 });
317
318 let freq_mask_len = freq_mask.len();
320 let (taken_columns, beginning, end) = freq_mask.iter().enumerate().fold(
321 (0, freq_mask_len, 0),
322 |(taken, beginning, end), (i, b)| {
323 b.then(|| (taken + 1, beginning.min(i), end.max(i + 1)))
324 .unwrap_or((taken, beginning, end))
325 },
326 );
327
328 if beginning >= end {
330 return Err(AnalysisError::AnalysisError(String::from(
331 "in chroma: no valid frequency range found",
332 )));
333 }
334 let mut pitches = Vec::with_capacity(taken_columns * length);
336 let mut mags = Vec::with_capacity(taken_columns * length);
337
338 let zipped = Zip::indexed(spectrum.slice(s![beginning..end - 3, ..]))
339 .and(spectrum.slice(s![beginning + 1..end - 2, ..]))
340 .and(spectrum.slice(s![beginning + 2..end - 1, ..]));
341
342 zipped.for_each(|(i, j), &before_elem, &elem, &after_elem| {
345 if elem > ref_value[j] && after_elem <= elem && before_elem < elem {
346 let avg = 0.5 * (after_elem - before_elem);
347 let mut shift = (2.0).mul_add(elem, -after_elem - before_elem);
348 if shift.abs() < f32::MIN_POSITIVE {
349 shift += 1.;
350 }
351 shift = avg / shift;
352 #[allow(clippy::cast_precision_loss)]
353 pitches.push(((i + beginning + 1) as f32 + shift) * sample_rate_float / n_fft as f32);
354 mags.push((0.5 * avg).mul_add(shift, elem));
355 }
356 });
357
358 Ok((pitches, mags))
359}
360
361#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
363#[inline]
364pub fn pitch_tuning(
365 mut frequencies: Array1<f32>,
366 resolution: f32,
367 bins_per_octave: u32,
368) -> AnalysisResult<f32> {
369 if unlikely(frequencies.is_empty()) {
370 return Ok(0.0);
371 }
372 hz_to_octs_inplace(&mut frequencies, 0.0, 12);
373 #[allow(clippy::cast_precision_loss)]
374 frequencies.mapv_inplace(|x| (bins_per_octave as f32 * x).fract());
375
376 frequencies.mapv_inplace(|x| if x >= 0.5 { x - 0.5 } else { x + 0.5 });
378
379 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
380 let indexes = (frequencies / resolution).mapv(|x| x as usize);
381 #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
382 let mut counts: Array1<usize> = Array::zeros(resolution.recip() as usize);
383 for &idx in &indexes {
384 counts[idx] += 1;
385 }
386 let max_index = counts
387 .argmax()
388 .map_err_unlikely(|e| AnalysisError::AnalysisError(format!("in chroma: {e}")))?;
389
390 #[allow(clippy::cast_precision_loss)]
392 Ok((100. * resolution).mul_add(max_index as f32, -50.) / 100.)
393}
394
395#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
396#[inline]
397pub fn estimate_tuning(
398 sample_rate: u32,
399 spectrum: &Array2<f32>,
400 n_fft: usize,
401 resolution: f32,
402 bins_per_octave: u32,
403) -> AnalysisResult<f32> {
404 let (pitch, mag) = pip_track(sample_rate, spectrum, n_fft)?;
405
406 if unlikely(pitch.is_empty()) {
407 return Ok(0.);
408 }
409
410 let (filtered_pitch, filtered_mag): (Vec<N32>, Vec<N32>) = pitch
411 .iter()
412 .zip(&mag)
413 .filter(|&(&p, _)| p > 0.)
414 .map(|(x, y)| (n32(*x), n32(*y)))
415 .unzip();
416
417 let mut mag_copy = filtered_mag.clone();
418 let mid = mag_copy.len() / 2;
419 let threshold = *mag_copy
420 .select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap())
421 .1;
422
423 let pitch = filtered_pitch
424 .iter()
425 .zip(&filtered_mag)
426 .filter_map(
427 |(&p, &m)| {
428 if m >= threshold { Some(p.into()) } else { None }
429 },
430 )
431 .collect::<Array1<f32>>();
432 pitch_tuning(pitch, resolution, bins_per_octave)
433}
434
435#[allow(
436 clippy::missing_errors_doc,
437 clippy::missing_panics_doc,
438 clippy::module_name_repetitions
439)]
440#[inline]
441pub fn chroma_stft(
442 sample_rate: u32,
443 spectrum: &Array2<f32>, n_fft: usize,
445 n_chroma: u32,
446 tuning: f32,
447) -> AnalysisResult<Array2<f32>> {
448 let mut raw_chroma = chroma_filter(sample_rate, n_fft, n_chroma, tuning)?;
449
450 raw_chroma = raw_chroma.dot(&spectrum.pow2());
451
452 raw_chroma = raw_chroma
455 .to_shape((raw_chroma.dim(), Order::ColumnMajor))
456 .map_err_unlikely(|_| {
457 AnalysisError::AnalysisError(String::from("in chroma: failed to reorder array"))
458 })?
459 .to_owned();
460
461 Zip::from(raw_chroma.columns_mut()).for_each(|mut row| {
462 let sum = row.sum(); if sum >= f32::MIN_POSITIVE {
464 row /= sum;
465 }
466 });
467
468 Ok(raw_chroma)
469}
470
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        SAMPLE_RATE,
        decoder::{Decoder as _, MecompDecoder as Decoder},
        utils::stft,
    };
    use ndarray::{Array2, arr1, arr2};
    use ndarray_npy::ReadNpyExt as _;
    use std::{fs::File, path::Path};

    /// The averaged template features of a reference chromagram match stored values.
    #[test]
    fn test_chroma_interval_features() {
        let file = File::open("data/chroma.npy").unwrap();
        let chroma = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);
        let features = chroma_interval_features(&chroma);
        let expected_features = arr1(&[
            0.038_602_84,
            0.021_852_81,
            0.042_243_79,
            0.063_852_78,
            0.073_111_48,
            0.025_125_66,
            0.003_198_99,
            0.003_113_08,
            0.001_074_33,
            0.002_418_61,
        ]);
        for (expected, actual) in expected_features.iter().zip(&features) {
            // The absolute value must wrap the whole difference; otherwise any
            // `actual` larger than `expected` passes.
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    /// The per-column template responses match a stored reference matrix.
    #[test]
    fn test_extract_interval_features() {
        let file = File::open("data/chroma-interval.npy").unwrap();
        let chroma = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);
        let templates = arr2(&[
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]);

        let file = File::open("data/interval-feature-matrix.npy").unwrap();
        let expected_interval_features = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let interval_features = extract_interval_features(&chroma, &templates);
        for (expected, actual) in expected_interval_features
            .iter()
            .zip(interval_features.iter())
        {
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    /// Every column of the result sums to one.
    #[test]
    fn test_normalize_feature_sequence() {
        let array = arr2(&[[0.1, 0.3, 0.4], [1.1, 0.53, 1.01]]);
        let expected_array = arr2(&[
            [0.083_333_33, 0.361_445_78, 0.283_687_94],
            [0.916_666_67, 0.638_554_22, 0.716_312_06],
        ]);

        let normalized_array = normalize_feature_sequence(&array);

        assert!(!array.is_empty() && !expected_array.is_empty());

        // Expected values first, so that the failure message labels them correctly.
        for (expected, actual) in expected_array.iter().zip(normalized_array.iter()) {
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    /// End-to-end check of the first ten features on a reference recording.
    #[test]
    fn test_chroma_desc() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let expected_values = [
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
        ];
        for (expected, actual) in expected_values.iter().zip(chroma_desc.get_value().iter()) {
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.01,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    /// The chromagram of a decoded recording matches a stored reference.
    #[test]
    fn test_chroma_stft_decode() {
        let signal = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap()
            .samples;
        let stft = stft(&signal, 8192, 2205);

        let file = File::open("data/chroma.npy").unwrap();
        let expected_chroma = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let chroma = chroma_stft(22050, &stft, 8192, 12, -0.049_999_999_999_999_99).unwrap();

        assert!(!chroma.is_empty() && !expected_chroma.is_empty());

        for (expected, actual) in expected_chroma.iter().zip(chroma.iter()) {
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.01,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    /// Tuning estimated from a stored spectrogram.
    #[test]
    fn test_estimate_tuning() {
        let file = File::open("data/spectrum-chroma.npy").unwrap();
        let arr = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let tuning = estimate_tuning(22050, &arr, 2048, 0.01, 12).unwrap();
        assert!(
            0.000_001 > (-0.099_999_999_999_999_98 - tuning).abs(),
            "{tuning} !~= -0.09999999999999998"
        );
    }

    /// An all-zero spectrogram yields no peaks and therefore a tuning of zero.
    #[test]
    fn test_chroma_estimate_tuning_empty_fix() {
        assert!(0. == estimate_tuning(22050, &Array2::zeros((8192, 1)), 8192, 0.01, 12).unwrap());
    }

    /// Tuning estimated from a decoded recording.
    #[test]
    fn test_estimate_tuning_decode() {
        let signal = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap()
            .samples;
        let stft = stft(&signal, 8192, 2205);

        let tuning = estimate_tuning(22050, &stft, 8192, 0.01, 12).unwrap();
        assert!(
            0.000_001 > (-0.049_999_999_999_999_99 - tuning).abs(),
            "{tuning} !~= -0.04999999999999999"
        );
    }

    /// Tuning of a stored list of pitches.
    #[test]
    fn test_pitch_tuning() {
        let file = File::open("data/pitch-tuning.npy").unwrap();
        let pitch = Array1::<f64>::read_npy(file).unwrap();
        #[allow(clippy::cast_possible_truncation)]
        let pitch = pitch.mapv(|x| x as f32);

        let tuned = pitch_tuning(pitch, 0.05, 12).unwrap();
        assert!(f32::EPSILON > (tuned + 0.1).abs(), "{tuned} != -0.1");
    }

    /// No frequencies at all means no tuning deviation.
    #[test]
    fn test_pitch_tuning_no_frequencies() {
        let frequencies = arr1(&[]);
        let tuned = pitch_tuning(frequencies, 0.05, 12).unwrap();
        assert!(f32::EPSILON > tuned.abs(), "{tuned} != 0");
    }

    /// Peaks found in a stored spectrogram match stored pitches and magnitudes.
    #[test]
    fn test_pip_track() {
        let file = File::open("data/spectrum-chroma.npy").unwrap();
        let spectrum = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let mags_file = File::open("data/spectrum-chroma-mags.npy").unwrap();
        let expected_mags = Array1::<f64>::read_npy(mags_file)
            .unwrap()
            .mapv(|x| x as f32);

        let pitches_file = File::open("data/spectrum-chroma-pitches.npy").unwrap();
        let expected_pitches = Array1::<f64>::read_npy(pitches_file)
            .unwrap()
            .mapv(|x| x as f32);

        // Both sides are compared in sorted order.
        let (mut pitches, mut mags) = pip_track(22050, &spectrum, 2048).unwrap();
        pitches.sort_by(|a, b| a.partial_cmp(b).unwrap());
        mags.sort_by(|a, b| a.partial_cmp(b).unwrap());

        for (expected_pitches, actual_pitches) in expected_pitches.iter().zip(pitches.iter()) {
            assert!(
                0.001 > (expected_pitches - actual_pitches).abs(),
                "{expected_pitches} !~= {actual_pitches}"
            );
        }
        for (expected_mags, actual_mags) in expected_mags.iter().zip(mags.iter()) {
            assert!(
                0.001 > (expected_mags - actual_mags).abs(),
                "{expected_mags} !~= {actual_mags}"
            );
        }
    }

    /// The filter bank is strictly positive and matches a stored reference.
    #[test]
    fn test_chroma_filter() {
        let file = File::open("data/chroma-filter.npy").unwrap();
        let expected_filter = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let filter = chroma_filter(22050, 2048, 12, -0.1).unwrap();

        assert!(filter.iter().all(|&x| x > 0.));

        for (expected, actual) in expected_filter.iter().zip(filter.iter()) {
            assert!(
                0.000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    /// A pure triad lights up exactly one of the four triad features (indices 6..=9).
    #[rstest::rstest]
    #[case::major_triad("data/chroma/Cmaj.ogg", 6)]
    #[case::major_triad("data/chroma/Dmaj.ogg", 6)]
    #[case::minor_triad("data/chroma/Cmin.ogg", 7)]
    #[case::diminished_triad("data/chroma/Cdim.ogg", 8)]
    #[case::augmented_triad("data/chroma/Caug.ogg", 9)]
    fn test_end_result_triads(
        #[case] path: &str,
        #[case] expected_dominant_chroma_feature_index: usize,
    ) {
        let song = Decoder::new().unwrap().decode(Path::new(path)).unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();

        // Feature indices ordered from the largest to the smallest value.
        let mut indices: Vec<usize> = (0..chroma_values.len()).collect();
        indices.sort_by(|&i, &j| chroma_values[j].partial_cmp(&chroma_values[i]).unwrap());
        assert_eq!(indices[0], expected_dominant_chroma_feature_index);
        for (i, v) in chroma_values.into_iter().enumerate() {
            // Only indices 6..=9 are triad features; index 10 onwards holds the
            // norm and ratio features, which follow different rules.
            if (6..10).contains(&i) {
                if i == expected_dominant_chroma_feature_index {
                    assert!(v > 0.8);
                } else {
                    assert!(v < 0.0);
                }
            }
        }
    }

    /// A tritone dyad nearly saturates the interval-class norm feature.
    #[test]
    fn test_end_l2_norm_dyad() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/dyad_tritone_IC6.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[10] > 0.9);
    }

    /// Major triads nearly saturate the triad norm feature.
    #[test]
    fn test_end_l2_norm_mode() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/Cmaj_triads.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[11] > 0.9);
    }

    /// An augmented triad pushes the triad/interval ratio feature high.
    #[test]
    fn test_end_l2_norm_ratio() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/triad_aug_maximize_ratio.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[12] > 0.7);
    }

    /// A pure dyad lights up exactly one of the six interval-class features; an
    /// interval and its inversion share the same class.
    #[rstest::rstest]
    #[case::minor_second("data/chroma/minor_second.ogg", 0)]
    #[case::major_second("data/chroma/major_second.ogg", 1)]
    #[case::minor_third("data/chroma/minor_third.ogg", 2)]
    #[case::major_third("data/chroma/major_third.ogg", 3)]
    #[case::perfect_fourth("data/chroma/perfect_fourth.ogg", 4)]
    #[case::tritone("data/chroma/tritone.ogg", 5)]
    #[case::perfect_fifth("data/chroma/perfect_fifth.ogg", 4)]
    #[case::minor_sixth("data/chroma/minor_sixth.ogg", 3)]
    #[case::major_sixth("data/chroma/major_sixth.ogg", 2)]
    #[case::minor_seventh("data/chroma/minor_seventh.ogg", 1)]
    #[case::major_seventh("data/chroma/major_seventh.ogg", 0)]
    fn test_end_result_intervals(
        #[case] path: &str,
        #[case] expected_dominant_chroma_feature_index: usize,
    ) {
        let song = Decoder::new().unwrap().decode(Path::new(path)).unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();

        // Feature indices ordered from the largest to the smallest value.
        let mut indices: Vec<usize> = (0..chroma_values.len()).collect();
        indices.sort_by(|&i, &j| chroma_values[j].partial_cmp(&chroma_values[i]).unwrap());
        assert_eq!(indices[0], expected_dominant_chroma_feature_index);
        for (i, v) in chroma_values.into_iter().enumerate() {
            if i < 6 {
                if i == expected_dominant_chroma_feature_index {
                    assert!(v > 0.9);
                } else {
                    assert!(v < 0.0);
                }
            }
        }
    }
}