mecomp_analysis/
chroma.rs

//! Chroma feature extraction module.
//!
//! Contains functions to compute the chromagram of a song, and
//! then extract interval features from this chromagram: the six interval
//! classes and the major / minor / diminished / augmented triads.
6use crate::Feature;
7
8use super::errors::{AnalysisError, AnalysisResult};
9use super::utils::{Normalize, hz_to_octs_inplace, stft};
10use bitvec::vec::BitVec;
11use likely_stable::{LikelyResult, likely, unlikely};
12use ndarray::{Array, Array1, Array2, Axis, Order, Zip, arr1, arr2, concatenate, s};
13use ndarray_stats::QuantileExt;
14use ndarray_stats::interpolate::Midpoint;
15use noisy_float::prelude::*;
16
17/**
18 * General object holding the chroma descriptor.
19 *
20 * Current chroma descriptors are interval features (see
21 * <https://speech.di.uoa.gr/ICMC-SMC-2014/images/VOL_2/1461.pdf>).
22 *
23 * Contrary to the other descriptors that can be used with streaming
24 * without consequences, this one performs better if the full song is used at
25 * once.
26 */
27#[derive(Debug, Clone)]
28#[allow(clippy::module_name_repetitions)]
29pub struct ChromaDesc {
30    sample_rate: u32,
31    n_chroma: u32,
32    values_chroma: Array2<f64>,
33}
34
35impl Normalize for ChromaDesc {
36    const MAX_VALUE: Feature = 1.0;
37    const MIN_VALUE: Feature = 0.;
38}
39
40impl ChromaDesc {
41    pub const WINDOW_SIZE: usize = 8192;
42    /// The theoretical maximum value for IC1-6 is each value at (1/2)².
43    /// The reason is that `extract_interval_features` computes the product of the
44    /// L1-normalized chroma vector (so all of its values are <= 1) by itself.
45    /// The maximum value of this is all coordinates to 1/2 (since dyads will
46    /// select three values). The maximum of this is then (1/2)², so the maximum of its
47    /// L2 norm is this sqrt(2 * (1/2)²) ~= 0.62. However, real-life simulations shown
48    /// that 0.25 is a good ceiling value (see tests).
49    pub const MAX_L2_INTERVAL: f64 = 0.25;
50    /// The theoretical maximum value for IC7-10 is each value at (1/3)³.
51    /// The reason is that `extract_interval_features` computes the product of the
52    /// L1-normalized chroma vector (so all of its values are <= 1) by itself.
53    /// The maximum value of this is all coordinates to 1/3 (since triads will
54    /// select three values). The maximum of this is then (1/3)³, so the maximum of its
55    /// L2 norm is this sqrt(4 * (1/3)³) ~= 0.074. However, real-life simulations shown
56    /// that 0.025 is a good ceiling value (see tests).
57    pub const MAX_L2_TRIAD: f64 = 0.025;
58    /// We are using atan2 to keep the ratio bounded.
59    pub const MAX_TRIAD_INTERVAL_RATIO: f64 = std::f64::consts::FRAC_PI_2;
60    #[must_use]
61    #[inline]
62    pub fn new(sample_rate: u32, n_chroma: u32) -> Self {
63        Self {
64            sample_rate,
65            n_chroma,
66            values_chroma: Array2::zeros((n_chroma as usize, 0)),
67        }
68    }
69
70    /**
71     * Compute and store the chroma of a signal.
72     *
73     * Passing a full song here once instead of streaming smaller parts of the
74     * song will greatly improve accuracy.
75     */
76    #[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
77    #[inline]
78    pub fn do_(&mut self, signal: &[f32]) -> AnalysisResult<()> {
79        let stft = stft(signal, Self::WINDOW_SIZE, 2205);
80        let tuning = estimate_tuning(self.sample_rate, &stft, Self::WINDOW_SIZE, 0.01, 12)?;
81        let chroma = chroma_stft(
82            self.sample_rate,
83            &stft,
84            Self::WINDOW_SIZE,
85            self.n_chroma,
86            tuning,
87        )?;
88        self.values_chroma = concatenate![Axis(1), self.values_chroma, chroma];
89        Ok(())
90    }
91
92    /**
93     * Get the song's interval features.
94     *
95     * Return the 6 pitch class set categories, as well as the major, minor,
96     * diminished and augmented triads.
97     *
98     * See this paper <https://speech.di.uoa.gr/ICMC-SMC-2014/images/VOL_2/1461.pdf>
99     * for more information ("Timbre-invariant Audio Features for Style Analysis of Classical
100     * Music").
101     */
102    #[inline]
103    pub fn get_value(&mut self) -> Vec<Feature> {
104        let mut raw_features = chroma_interval_features(&self.values_chroma);
105        let (mut interval_class, mut interval_class_mode) =
106            raw_features.view_mut().split_at(Axis(0), 6);
107        // Compute those two norms separately because the values for the IC1-6 and IC7-10 don't
108        // have the same range.
109        let l2_norm_interval_class = interval_class.dot(&interval_class).sqrt();
110        let l2_norm_interval_class_mode = interval_class_mode.dot(&interval_class_mode).sqrt();
111        if l2_norm_interval_class > 0. {
112            interval_class /= l2_norm_interval_class;
113        }
114        if l2_norm_interval_class_mode > 0. {
115            interval_class_mode /= l2_norm_interval_class_mode;
116        }
117        let mut features = raw_features.mapv_into_any(|x| self.normalize(x)).to_vec();
118
119        let normalized_l2_norm_interval_class =
120            (2. * (l2_norm_interval_class - 0.) / (Self::MAX_L2_INTERVAL - 0.) - 1.).min(1.);
121        features.push(normalized_l2_norm_interval_class);
122        let normalized_l2_norm_interval_class_mode =
123            (2. * (l2_norm_interval_class_mode - 0.) / (Self::MAX_L2_TRIAD - 0.) - 1.).min(1.);
124        features.push(normalized_l2_norm_interval_class_mode);
125        let angle = (20. * l2_norm_interval_class_mode).atan2(l2_norm_interval_class + 1e-12_f64);
126        let normalized_ratio = 2. * (angle - 0.) / (Self::MAX_TRIAD_INTERVAL_RATIO - 0.) - 1.;
127        features.push(normalized_ratio);
128        features
129    }
130}
131
132// Functions below are Rust versions of python notebooks by AudioLabs Erlang
133// (<https://www.audiolabs-erlangen.de/resources/MIR/FMP/C0/C0.html>)
134#[allow(
135    clippy::missing_errors_doc,
136    clippy::missing_panics_doc,
137    clippy::module_name_repetitions
138)]
139#[must_use]
140#[inline]
141pub fn chroma_interval_features(chroma: &Array2<f64>) -> Array1<f64> {
142    let chroma = normalize_feature_sequence(&(chroma * 15.).exp());
143    let templates = arr2(&[
144        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
145        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
146        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
147        [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
148        [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
149        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
150        [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
151        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
152        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
153        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
154        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
155        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
156    ]);
157    let interval_feature_matrix = extract_interval_features(&chroma, &templates);
158    interval_feature_matrix.mean_axis(Axis(1)).unwrap()
159}
160
161#[must_use]
162#[inline]
163pub fn extract_interval_features(chroma: &Array2<f64>, templates: &Array2<i32>) -> Array2<f64> {
164    let mut f_intervals: Array2<f64> = Array::zeros((chroma.shape()[1], templates.shape()[1]));
165    for (template, mut f_interval) in templates
166        .axis_iter(Axis(1))
167        .zip(f_intervals.axis_iter_mut(Axis(1)))
168    {
169        for shift in 0..12 {
170            let mut vec: Vec<i32> = template.to_vec();
171            vec.rotate_right(shift);
172            let rolled = arr1(&vec);
173            let power = Zip::from(chroma.t())
174                .and_broadcast(&rolled)
175                .map_collect(|&f, &s| f.powi(s))
176                .map_axis_mut(Axis(1), |x| x.product());
177            f_interval += &power;
178        }
179    }
180    f_intervals.t().to_owned()
181}
182
183#[inline]
184pub fn normalize_feature_sequence(feature: &Array2<f64>) -> Array2<f64> {
185    let mut normalized_sequence = feature.to_owned();
186    for mut column in normalized_sequence.columns_mut() {
187        let sum: f64 = column.iter().copied().map(f64::abs).sum();
188        if likely(sum >= 0.0001) {
189            column /= sum;
190        }
191    }
192
193    normalized_sequence
194}
195
196// All the functions below are more than heavily inspired from
197// librosa"s code: https://github.com/librosa/librosa/blob/main/librosa/feature/spectral.py#L1165
198// chroma(22050, n_fft=5, n_chroma=12)
199//
200// Could be precomputed, but it takes very little time to compute it
201// on the fly compared to the rest of the functions, and we'd lose the
202// possibility to tweak parameters.
203#[allow(
204    clippy::missing_errors_doc,
205    clippy::missing_panics_doc,
206    clippy::module_name_repetitions,
207    clippy::missing_inline_in_public_items
208)]
209pub fn chroma_filter(
210    sample_rate: u32,
211    n_fft: usize,
212    n_chroma: u32,
213    tuning: f64,
214) -> AnalysisResult<Array2<f64>> {
215    let ctroct = 5.0;
216    let octwidth = 2.;
217    let n_chroma_float = f64::from(n_chroma);
218    let n_chroma2 = (n_chroma_float / 2.0).round();
219
220    let frequencies = Array::linspace(0., f64::from(sample_rate), n_fft + 1);
221
222    let mut freq_bins = frequencies;
223    hz_to_octs_inplace(&mut freq_bins, tuning, n_chroma);
224    freq_bins *= n_chroma_float;
225    freq_bins[0] = 1.5f64.mul_add(-n_chroma_float, freq_bins[1]);
226
227    let mut binwidth_bins = Array::ones(freq_bins.raw_dim());
228    binwidth_bins
229        .slice_mut(s![0..freq_bins.len() - 1])
230        .assign(&(&freq_bins.slice(s![1..]) - &freq_bins.slice(s![..-1])).mapv(|x| x.max(1.)));
231
232    let mut d: Array2<f64> = Array::zeros((n_chroma as usize, (freq_bins).len()));
233    for (idx, mut row) in d.rows_mut().into_iter().enumerate() {
234        #[allow(clippy::cast_precision_loss)]
235        row.fill(idx as f64);
236    }
237    d = -d + &freq_bins;
238
239    d = d + n_chroma2 + 10. * n_chroma_float;
240    d = d % n_chroma_float - n_chroma2;
241    d = d / binwidth_bins;
242    d = (-2. * d.pow2()).exp();
243
244    let mut wts = d;
245    // Normalize by computing the l2-norm over the columns
246    for mut col in wts.columns_mut() {
247        let sum = col.pow2().sum().sqrt();
248        if sum >= f64::MIN_POSITIVE {
249            col /= sum;
250        }
251    }
252
253    // Apply Gaussian tuning curve
254    freq_bins = (-0.5 * ((freq_bins / n_chroma_float - ctroct) / octwidth).powi(2)).exp();
255
256    wts *= &freq_bins;
257
258    // np.roll(), np bro
259    let mut b = Array2::zeros(wts.dim());
260    b.slice_mut(s![-3.., ..]).assign(&wts.slice(s![..3, ..]));
261    b.slice_mut(s![..-3, ..]).assign(&wts.slice(s![3.., ..]));
262
263    wts = b;
264    let non_aliased = 1 + n_fft / 2;
265    Ok(wts.slice_move(s![.., ..non_aliased]))
266}
267
268#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
269#[allow(clippy::missing_inline_in_public_items)]
270pub fn pip_track(
271    sample_rate: u32,
272    spectrum: &Array2<f64>,
273    n_fft: usize,
274) -> AnalysisResult<(Vec<f64>, Vec<f64>)> {
275    let sample_rate_float = f64::from(sample_rate);
276    let fmin = 150.0_f64;
277    let fmax = 4000.0_f64.min(sample_rate_float / 2.0);
278    let threshold = 0.1;
279
280    let fft_freqs = Array::linspace(0., sample_rate_float / 2., 1 + n_fft / 2);
281
282    let length = spectrum.len_of(Axis(0));
283
284    let freq_mask = fft_freqs
285        .iter()
286        .map(|&f| (fmin <= f) && (f < fmax))
287        .collect::<BitVec>();
288
289    let ref_value = spectrum.map_axis(Axis(0), |x| {
290        let first: f64 = *x.first().expect("empty spectrum axis");
291        let max = x.fold(first, |acc, &elem| acc.max(elem));
292        threshold * max
293    });
294
295    // compute number of taken columns and beginning / end indices
296    let freq_mask_len = freq_mask.len();
297    let (taken_columns, beginning, end) = freq_mask.iter().enumerate().fold(
298        (0, freq_mask_len, 0),
299        |(taken, beginning, end), (i, b)| {
300            b.then(|| (taken + 1, beginning.min(i), end.max(i + 1)))
301                .unwrap_or((taken, beginning, end))
302        },
303    );
304
305    // Validate that a valid frequency range was found
306    if beginning >= end {
307        return Err(AnalysisError::AnalysisError(String::from(
308            "in chroma: no valid frequency range found",
309        )));
310    }
311    // There will be at most taken_columns * length elements in pitches / mags
312    let mut pitches = Vec::with_capacity(taken_columns * length);
313    let mut mags = Vec::with_capacity(taken_columns * length);
314
315    let zipped = Zip::indexed(spectrum.slice(s![beginning..end - 3, ..]))
316        .and(spectrum.slice(s![beginning + 1..end - 2, ..]))
317        .and(spectrum.slice(s![beginning + 2..end - 1, ..]));
318
319    // No need to handle the last column, since freq_mask[length - 1] is
320    // always going to be `false` for 22.5kHz
321    zipped.for_each(|(i, j), &before_elem, &elem, &after_elem| {
322        if elem > ref_value[j] && after_elem <= elem && before_elem < elem {
323            let avg = 0.5 * (after_elem - before_elem);
324            let mut shift = 2f64.mul_add(elem, -after_elem - before_elem);
325            if shift.abs() < f64::MIN_POSITIVE {
326                shift += 1.;
327            }
328            shift = avg / shift;
329            #[allow(clippy::cast_precision_loss)]
330            pitches.push(((i + beginning + 1) as f64 + shift) * sample_rate_float / n_fft as f64);
331            mags.push((0.5 * avg).mul_add(shift, elem));
332        }
333    });
334
335    Ok((pitches, mags))
336}
337
338// Only use this with strictly positive `frequencies`.
339#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
340#[inline]
341pub fn pitch_tuning(
342    frequencies: &mut Array1<f64>,
343    resolution: f64,
344    bins_per_octave: u32,
345) -> AnalysisResult<f64> {
346    if unlikely(frequencies.is_empty()) {
347        return Ok(0.0);
348    }
349    hz_to_octs_inplace(frequencies, 0.0, 12);
350    frequencies.mapv_inplace(|x| f64::from(bins_per_octave) * x % 1.0);
351
352    // Put everything between -0.5 and 0.5.
353    frequencies.mapv_inplace(|x| if x >= 0.5 { x - 1. } else { x });
354
355    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
356    let indexes = ((frequencies.to_owned() - -0.5) / resolution).mapv(|x| x as usize);
357    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
358    let mut counts: Array1<usize> = Array::zeros(((0.5 - -0.5) / resolution) as usize);
359    for &idx in &indexes {
360        counts[idx] += 1;
361    }
362    let max_index = counts
363        .argmax()
364        .map_err_unlikely(|e| AnalysisError::AnalysisError(format!("in chroma: {e}")))?;
365
366    // Return the bin with the most reoccurring frequency.
367    #[allow(clippy::cast_precision_loss)]
368    Ok((100. * resolution).mul_add(max_index as f64, -50.) / 100.)
369}
370
371#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
372#[inline]
373pub fn estimate_tuning(
374    sample_rate: u32,
375    spectrum: &Array2<f64>,
376    n_fft: usize,
377    resolution: f64,
378    bins_per_octave: u32,
379) -> AnalysisResult<f64> {
380    let (pitch, mag) = pip_track(sample_rate, spectrum, n_fft)?;
381
382    let (filtered_pitch, filtered_mag): (Vec<N64>, Vec<N64>) = pitch
383        .iter()
384        .zip(&mag)
385        .filter(|&(&p, _)| p > 0.)
386        .map(|(x, y)| (n64(*x), n64(*y)))
387        .unzip();
388
389    if unlikely(pitch.is_empty()) {
390        return Ok(0.);
391    }
392
393    let threshold: N64 = Array::from(filtered_mag.clone())
394        .quantile_axis_mut(Axis(0), n64(0.5), &Midpoint)
395        .map_err_unlikely(|e| AnalysisError::AnalysisError(format!("in chroma: {e}")))?
396        .into_scalar();
397    let mut pitch = filtered_pitch
398        .iter()
399        .zip(&filtered_mag)
400        .filter_map(|(&p, &m)| if m >= threshold { Some(p.into()) } else { None })
401        .collect::<Array1<f64>>();
402    pitch_tuning(&mut pitch, resolution, bins_per_octave)
403}
404
405#[allow(
406    clippy::missing_errors_doc,
407    clippy::missing_panics_doc,
408    clippy::module_name_repetitions
409)]
410#[inline]
411pub fn chroma_stft(
412    sample_rate: u32,
413    spectrum: &Array2<f64>, // shape: (window_length / 2 + 1, signal.len().div_ceil(hop_length))
414    n_fft: usize,
415    n_chroma: u32,
416    tuning: f64,
417) -> AnalysisResult<Array2<f64>> {
418    let mut raw_chroma = chroma_filter(sample_rate, n_fft, n_chroma, tuning)?;
419
420    raw_chroma = raw_chroma.dot(&spectrum.pow2());
421
422    // We want to maximize cache locality, and are iterating over columns,
423    // so let's make sure our array is in column-major order.
424    raw_chroma = raw_chroma
425        .to_shape((raw_chroma.dim(), Order::ColumnMajor))
426        .map_err_unlikely(|_| {
427            AnalysisError::AnalysisError(String::from("in chroma: failed to reorder array"))
428        })?
429        .to_owned();
430
431    Zip::from(raw_chroma.columns_mut()).for_each(|mut row| {
432        let sum = row.sum(); // we know that our values are positive, so no need to use abs
433        if sum >= f64::MIN_POSITIVE {
434            row /= sum;
435        }
436    });
437
438    Ok(raw_chroma)
439}
440
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        SAMPLE_RATE,
        decoder::{Decoder as _, MecompDecoder as Decoder},
        utils::stft,
    };
    use ndarray::{Array2, arr1, arr2};
    use ndarray_npy::ReadNpyExt as _;
    use std::{fs::File, path::Path};

    #[test]
    fn test_chroma_interval_features() {
        let file = File::open("data/chroma.npy").unwrap();
        let chroma = Array2::<f64>::read_npy(file).unwrap();
        let features = chroma_interval_features(&chroma);
        let expected_features = arr1(&[
            0.038_602_84,
            0.021_852_81,
            0.042_243_79,
            0.063_852_78,
            0.073_111_48,
            0.025_125_66,
            0.003_198_99,
            0.003_113_08,
            0.001_074_33,
            0.002_418_61,
        ]);
        assert_eq!(features.len(), expected_features.len());
        for (expected, actual) in expected_features.iter().zip(&features) {
            // Two-sided check (the expected values are rounded to 8 decimals).
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[test]
    fn test_extract_interval_features() {
        let file = File::open("data/chroma-interval.npy").unwrap();
        let chroma = Array2::<f64>::read_npy(file).unwrap();
        let templates = arr2(&[
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]);

        let file = File::open("data/interval-feature-matrix.npy").unwrap();
        let expected_interval_features = Array2::<f64>::read_npy(file).unwrap();

        let interval_features = extract_interval_features(&chroma, &templates);
        for (expected, actual) in expected_interval_features
            .iter()
            .zip(interval_features.iter())
        {
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[test]
    fn test_normalize_feature_sequence() {
        let array = arr2(&[[0.1, 0.3, 0.4], [1.1, 0.53, 1.01]]);
        let expected_array = arr2(&[
            [0.083_333_33, 0.361_445_78, 0.283_687_94],
            [0.916_666_67, 0.638_554_22, 0.716_312_06],
        ]);

        let normalized_array = normalize_feature_sequence(&array);

        assert!(!array.is_empty() && !expected_array.is_empty());

        for (expected, actual) in normalized_array.iter().zip(expected_array.iter()) {
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[test]
    fn test_chroma_desc() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let expected_values = [
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
        ];
        for (expected, actual) in expected_values.iter().zip(chroma_desc.get_value().iter()) {
            // original test wanted absolute error < 0.0000001
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.01,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    #[test]
    fn test_chroma_stft_decode() {
        let signal = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap()
            .samples;
        let stft = stft(&signal, 8192, 2205);

        let file = File::open("data/chroma.npy").unwrap();
        let expected_chroma = Array2::<f64>::read_npy(file).unwrap();

        let chroma = chroma_stft(22050, &stft, 8192, 12, -0.049_999_999_999_999_99).unwrap();

        assert!(!chroma.is_empty() && !expected_chroma.is_empty());

        for (expected, actual) in expected_chroma.iter().zip(chroma.iter()) {
            // original test wanted absolute error < 0.0000001
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.01,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    #[test]
    fn test_estimate_tuning() {
        let file = File::open("data/spectrum-chroma.npy").unwrap();
        let arr = Array2::<f64>::read_npy(file).unwrap();

        let tuning = estimate_tuning(22050, &arr, 2048, 0.01, 12).unwrap();
        assert!(
            0.000_001 > (-0.099_999_999_999_999_98 - tuning).abs(),
            "{tuning} !~= -0.09999999999999998"
        );
    }

    #[test]
    fn test_chroma_estimate_tuning_empty_fix() {
        assert!(0. == estimate_tuning(22050, &Array2::zeros((8192, 1)), 8192, 0.01, 12).unwrap());
    }

    #[test]
    fn test_estimate_tuning_decode() {
        let signal = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap()
            .samples;
        let stft = stft(&signal, 8192, 2205);

        let tuning = estimate_tuning(22050, &stft, 8192, 0.01, 12).unwrap();
        assert!(
            0.000_001 > (-0.049_999_999_999_999_99 - tuning).abs(),
            "{tuning} !~= -0.04999999999999999"
        );
    }

    #[test]
    fn test_pitch_tuning() {
        let file = File::open("data/pitch-tuning.npy").unwrap();
        let mut pitch = Array1::<f64>::read_npy(file).unwrap();
        let tuned = pitch_tuning(&mut pitch, 0.05, 12).unwrap();
        assert!(f64::EPSILON > (tuned + 0.1).abs(), "{tuned} != -0.1");
    }

    #[test]
    fn test_pitch_tuning_no_frequencies() {
        let mut frequencies = arr1(&[]);
        let tuned = pitch_tuning(&mut frequencies, 0.05, 12).unwrap();
        assert!(f64::EPSILON > tuned.abs(), "{tuned} != 0");
    }

    #[test]
    fn test_pip_track() {
        let file = File::open("data/spectrum-chroma.npy").unwrap();
        let spectrum = Array2::<f64>::read_npy(file).unwrap();

        let mags_file = File::open("data/spectrum-chroma-mags.npy").unwrap();
        let expected_mags = Array1::<f64>::read_npy(mags_file).unwrap();

        let pitches_file = File::open("data/spectrum-chroma-pitches.npy").unwrap();
        let expected_pitches = Array1::<f64>::read_npy(pitches_file).unwrap();

        let (mut pitches, mut mags) = pip_track(22050, &spectrum, 2048).unwrap();
        pitches.sort_by(|a, b| a.partial_cmp(b).unwrap());
        mags.sort_by(|a, b| a.partial_cmp(b).unwrap());

        for (expected_pitches, actual_pitches) in expected_pitches.iter().zip(pitches.iter()) {
            assert!(
                0.000_000_01 > (expected_pitches - actual_pitches).abs(),
                "{expected_pitches} !~= {actual_pitches}"
            );
        }
        for (expected_mags, actual_mags) in expected_mags.iter().zip(mags.iter()) {
            assert!(
                0.000_000_01 > (expected_mags - actual_mags).abs(),
                "{expected_mags} !~= {actual_mags}"
            );
        }
    }

    #[test]
    fn test_chroma_filter() {
        let file = File::open("data/chroma-filter.npy").unwrap();
        let expected_filter = Array2::<f64>::read_npy(file).unwrap();

        let filter = chroma_filter(22050, 2048, 12, -0.1).unwrap();

        assert!(filter.iter().all(|&x| x > 0.));

        for (expected, actual) in expected_filter.iter().zip(filter.iter()) {
            assert!(
                0.000_000_001 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[rstest::rstest]
    // High 6 should be a major triad, 7 minor, 8 diminished and 9 augmented.
    #[case::major_triad("data/chroma/Cmaj.ogg", 6)]
    #[case::major_triad("data/chroma/Dmaj.ogg", 6)]
    #[case::minor_triad("data/chroma/Cmin.ogg", 7)]
    #[case::diminished_triad("data/chroma/Cdim.ogg", 8)]
    #[case::augmented_triad("data/chroma/Caug.ogg", 9)]
    fn test_end_result_triads(
        #[case] path: &str,
        #[case] expected_dominant_chroma_feature_index: usize,
    ) {
        let song = Decoder::new().unwrap().decode(Path::new(path)).unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();

        let mut indices: Vec<usize> = (0..chroma_values.len()).collect();
        indices.sort_by(|&i, &j| chroma_values[j].partial_cmp(&chroma_values[i]).unwrap());
        assert_eq!(indices[0], expected_dominant_chroma_feature_index);
        for (i, v) in chroma_values.into_iter().enumerate() {
            // Indices 6-9 are the triad features; index 10 (normalized IC1-6 L2 norm) is
            // checked as well.
            if (6..=10).contains(&i) {
                if i == expected_dominant_chroma_feature_index {
                    assert!(v > 0.8);
                } else {
                    assert!(v < 0.0);
                }
            }
        }
    }

    #[test]
    fn test_end_l2_norm_dyad() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/dyad_tritone_IC6.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[10] > 0.9);
    }

    #[test]
    fn test_end_l2_norm_mode() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/Cmaj_triads.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[11] > 0.9);
    }

    #[test]
    fn test_end_l2_norm_ratio() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/triad_aug_maximize_ratio.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[12] > 0.7);
    }

    #[rstest::rstest]
    // Test all 12 intervals.
    #[case::minor_second("data/chroma/minor_second.ogg", 0)]
    #[case::major_second("data/chroma/major_second.ogg", 1)]
    #[case::minor_third("data/chroma/minor_third.ogg", 2)]
    #[case::major_third("data/chroma/major_third.ogg", 3)]
    #[case::perfect_fourth("data/chroma/perfect_fourth.ogg", 4)]
    #[case::tritone("data/chroma/tritone.ogg", 5)]
    #[case::perfect_fifth("data/chroma/perfect_fifth.ogg", 4)]
    #[case::minor_sixth("data/chroma/minor_sixth.ogg", 3)]
    #[case::major_sixth("data/chroma/major_sixth.ogg", 2)]
    #[case::minor_seventh("data/chroma/minor_seventh.ogg", 1)]
    #[case::major_seventh("data/chroma/major_seventh.ogg", 0)]
    fn test_end_result_intervals(
        #[case] path: &str,
        #[case] expected_dominant_chroma_feature_index: usize,
    ) {
        let song = Decoder::new().unwrap().decode(Path::new(path)).unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();

        let mut indices: Vec<usize> = (0..chroma_values.len()).collect();
        indices.sort_by(|&i, &j| chroma_values[j].partial_cmp(&chroma_values[i]).unwrap());
        assert_eq!(indices[0], expected_dominant_chroma_feature_index);
        for (i, v) in chroma_values.into_iter().enumerate() {
            if i < 6 {
                if i == expected_dominant_chroma_feature_index {
                    assert!(v > 0.9);
                } else {
                    assert!(v < 0.0);
                }
            }
        }
    }
}