// mecomp_analysis: chroma.rs

//! Chroma feature extraction module.
//!
//! Contains functions to compute the chromagram of a song, and then extract
//! interval features from this chromagram: six interval-class (dyad) features
//! plus major, minor, diminished and augmented triad features.
6use crate::Feature;
7
8use super::errors::{AnalysisError, AnalysisResult};
9use super::utils::{Normalize, hz_to_octs_inplace, stft};
10use bitvec::vec::BitVec;
11use likely_stable::{LikelyResult, likely, unlikely};
12use ndarray::{Array, Array1, Array2, Axis, Order, Zip, arr2, concatenate, s};
13use ndarray_stats::QuantileExt;
14use noisy_float::prelude::*;
15
16/**
17 * General object holding the chroma descriptor.
18 *
19 * Current chroma descriptors are interval features (see
20 * <https://speech.di.uoa.gr/ICMC-SMC-2014/images/VOL_2/1461.pdf>).
21 *
22 * Contrary to the other descriptors that can be used with streaming
23 * without consequences, this one performs better if the full song is used at
24 * once.
25 */
26#[derive(Debug, Clone)]
27#[allow(clippy::module_name_repetitions)]
28pub struct ChromaDesc {
29    sample_rate: u32,
30    n_chroma: u32,
31    values_chroma: Array2<f32>,
32}
33
34impl Normalize for ChromaDesc {
35    const MAX_VALUE: Feature = 1.0;
36    const MIN_VALUE: Feature = 0.;
37}
38
39impl ChromaDesc {
40    pub const WINDOW_SIZE: usize = 8192;
41    /// The theoretical maximum value for IC1-6 is each value at (1/2)².
42    /// The reason is that `extract_interval_features` computes the product of the
43    /// L1-normalized chroma vector (so all of its values are <= 1) by itself.
44    /// The maximum value of this is all coordinates to 1/2 (since dyads will
45    /// select three values). The maximum of this is then (1/2)², so the maximum of its
46    /// L2 norm is this sqrt(2 * (1/2)²) ~= 0.62. However, real-life simulations shown
47    /// that 0.25 is a good ceiling value (see tests).
48    pub const MAX_L2_INTERVAL: f32 = 0.25;
49    /// The theoretical maximum value for IC7-10 is each value at (1/3)³.
50    /// The reason is that `extract_interval_features` computes the product of the
51    /// L1-normalized chroma vector (so all of its values are <= 1) by itself.
52    /// The maximum value of this is all coordinates to 1/3 (since triads will
53    /// select three values). The maximum of this is then (1/3)³, so the maximum of its
54    /// L2 norm is this sqrt(4 * (1/3)³) ~= 0.074. However, real-life simulations shown
55    /// that 0.025 is a good ceiling value (see tests).
56    pub const MAX_L2_TRIAD: f32 = 0.025;
57    /// We are using atan2 to keep the ratio bounded.
58    pub const MAX_TRIAD_INTERVAL_RATIO: f32 = std::f32::consts::FRAC_PI_2;
59    #[must_use]
60    #[inline]
61    pub fn new(sample_rate: u32, n_chroma: u32) -> Self {
62        Self {
63            sample_rate,
64            n_chroma,
65            values_chroma: Array2::zeros((n_chroma as usize, 0)),
66        }
67    }
68
69    /**
70     * Compute and store the chroma of a signal.
71     *
72     * Passing a full song here once instead of streaming smaller parts of the
73     * song will greatly improve accuracy.
74     */
75    #[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
76    #[inline]
77    pub fn do_(&mut self, signal: &[f32]) -> AnalysisResult<()> {
78        let stft = stft(signal, Self::WINDOW_SIZE, 2205);
79        let tuning = estimate_tuning(self.sample_rate, &stft, Self::WINDOW_SIZE, 0.01, 12)?;
80        let chroma = chroma_stft(
81            self.sample_rate,
82            &stft,
83            Self::WINDOW_SIZE,
84            self.n_chroma,
85            tuning,
86        )?;
87        self.values_chroma = concatenate![Axis(1), self.values_chroma, chroma];
88        Ok(())
89    }
90
91    /**
92     * Get the song's interval features.
93     *
94     * Return the 6 pitch class set categories, as well as the major, minor,
95     * diminished and augmented triads.
96     *
97     * See this paper <https://speech.di.uoa.gr/ICMC-SMC-2014/images/VOL_2/1461.pdf>
98     * for more information ("Timbre-invariant Audio Features for Style Analysis of Classical
99     * Music").
100     */
101    #[inline]
102    pub fn get_value(&mut self) -> Vec<Feature> {
103        let mut raw_features = chroma_interval_features(&self.values_chroma);
104        let (mut interval_class, mut interval_class_mode) =
105            raw_features.view_mut().split_at(Axis(0), 6);
106        // Compute those two norms separately because the values for the IC1-6 and IC7-10 don't
107        // have the same range.
108        let l2_norm_interval_class = interval_class.dot(&interval_class).sqrt();
109        let l2_norm_interval_class_mode = interval_class_mode.dot(&interval_class_mode).sqrt();
110        if l2_norm_interval_class > 0. {
111            interval_class /= l2_norm_interval_class;
112        }
113        if l2_norm_interval_class_mode > 0. {
114            interval_class_mode /= l2_norm_interval_class_mode;
115        }
116        let mut features = raw_features.mapv_into_any(|x| self.normalize(x)).to_vec();
117
118        let normalized_l2_norm_interval_class =
119            (2. * l2_norm_interval_class / Self::MAX_L2_INTERVAL - 1.).min(1.);
120        features.push(normalized_l2_norm_interval_class);
121        let normalized_l2_norm_interval_class_mode =
122            (2. * l2_norm_interval_class_mode / Self::MAX_L2_TRIAD - 1.).min(1.);
123        features.push(normalized_l2_norm_interval_class_mode);
124        let angle = (20. * l2_norm_interval_class_mode).atan2(l2_norm_interval_class + 1e-12_f32);
125        let normalized_ratio = 2. * angle / Self::MAX_TRIAD_INTERVAL_RATIO - 1.;
126        features.push(normalized_ratio);
127        features
128    }
129}
130
131// Functions below are Rust versions of python notebooks by AudioLabs Erlang
132// (<https://www.audiolabs-erlangen.de/resources/MIR/FMP/C0/C0.html>)
133#[allow(
134    clippy::missing_errors_doc,
135    clippy::missing_panics_doc,
136    clippy::module_name_repetitions
137)]
138#[must_use]
139#[inline]
140pub fn chroma_interval_features(chroma: &Array2<f32>) -> Array1<f32> {
141    let chroma = normalize_feature_sequence(&(chroma * 15.).exp());
142    let templates = arr2(&[
143        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
144        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
145        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
146        [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
147        [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
148        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
149        [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
150        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
151        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
152        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
153        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
154        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
155    ]);
156    let interval_feature_matrix = extract_interval_features(&chroma, &templates);
157    interval_feature_matrix.mean_axis(Axis(1)).unwrap()
158}
159
160#[must_use]
161#[inline]
162pub fn extract_interval_features(chroma: &Array2<f32>, templates: &Array2<i32>) -> Array2<f32> {
163    let n_templates = templates.shape()[1];
164    let n_chroma = chroma.shape()[0]; // should be 12
165    let n_cols = chroma.shape()[1];
166
167    let chroma_t = chroma.t();
168    let mut f_intervals: Array2<f32> = Array::zeros((n_templates, n_cols));
169
170    // precompute active indices for all templates
171    let active_indices: Vec<Vec<usize>> = templates
172        .axis_iter(Axis(1))
173        .map(|t| {
174            t.iter()
175                .enumerate()
176                .filter_map(|(i, &v)| (v == 1).then_some(i))
177                .collect()
178        })
179        .collect();
180
181    // For each column in chroma, compute the product of values at shifted active indices
182    for (col_idx, col) in chroma_t.rows().into_iter().enumerate() {
183        for (tmpl_idx, indices) in active_indices.iter().enumerate() {
184            let sum = (0..n_chroma)
185                .map(|shift| {
186                    indices
187                        .iter()
188                        .map(|&idx| col[(idx + shift) % n_chroma])
189                        .product::<f32>()
190                })
191                .sum();
192            f_intervals[(tmpl_idx, col_idx)] = sum;
193        }
194    }
195
196    f_intervals
197}
198
199#[inline]
200pub fn normalize_feature_sequence(feature: &Array2<f32>) -> Array2<f32> {
201    let mut normalized_sequence = feature.to_owned();
202    for mut column in normalized_sequence.columns_mut() {
203        let sum: f32 = column.iter().copied().map(f32::abs).sum();
204        if likely(sum >= 0.0001) {
205            column /= sum;
206        }
207    }
208
209    normalized_sequence
210}
211
212// All the functions below are more than heavily inspired from
213// librosa"s code: https://github.com/librosa/librosa/blob/main/librosa/feature/spectral.py#L1165
214// chroma(22050, n_fft=5, n_chroma=12)
215//
216// Could be precomputed, but it takes very little time to compute it
217// on the fly compared to the rest of the functions, and we'd lose the
218// possibility to tweak parameters.
219#[allow(
220    clippy::missing_errors_doc,
221    clippy::missing_panics_doc,
222    clippy::module_name_repetitions,
223    clippy::missing_inline_in_public_items
224)]
225pub fn chroma_filter(
226    sample_rate: u32,
227    n_fft: usize,
228    n_chroma: u32,
229    tuning: f32,
230) -> AnalysisResult<Array2<f32>> {
231    let ctroct = 5.0;
232    let octwidth = 2.;
233    #[allow(clippy::cast_precision_loss)]
234    let n_chroma2 = (n_chroma >> 1) as f32;
235    #[allow(clippy::cast_precision_loss)]
236    let n_chroma_float = n_chroma as f32;
237
238    #[allow(clippy::cast_precision_loss)]
239    let frequencies = Array::linspace(0., sample_rate as f32, n_fft + 1);
240
241    let mut freq_bins = frequencies;
242    hz_to_octs_inplace(&mut freq_bins, tuning, n_chroma);
243    freq_bins *= n_chroma_float;
244    freq_bins[0] = (1.5).mul_add(-n_chroma_float, freq_bins[1]);
245    let mut binwidth_bins = Array::ones(freq_bins.raw_dim());
246    binwidth_bins
247        .slice_mut(s![0..freq_bins.len() - 1])
248        .assign(&(&freq_bins.slice(s![1..]) - &freq_bins.slice(s![..-1])).mapv(|x| x.max(1.)));
249
250    let mut d: Array2<f32> = Array::zeros((n_chroma as usize, (freq_bins).len()));
251    for (idx, mut row) in d.rows_mut().into_iter().enumerate() {
252        #[allow(clippy::cast_precision_loss)]
253        row.fill(idx as f32);
254    }
255
256    d.zip_mut_with(&freq_bins, |d_elem, &fb| {
257        let x = -*d_elem + fb;
258        let x = n_chroma_float.mul_add(10., x + n_chroma2);
259        *d_elem = x % n_chroma_float - n_chroma2;
260    });
261    d.zip_mut_with(&binwidth_bins, |d_elem, &bb| {
262        let x = *d_elem / bb;
263        *d_elem = (-2. * x * x).exp();
264    });
265
266    let mut wts = d;
267    // Normalize by computing the l2-norm over the columns
268    for mut col in wts.columns_mut() {
269        let sum = col.pow2().sum().sqrt();
270        if sum >= f32::MIN_POSITIVE {
271            col /= sum;
272        }
273    }
274
275    // Apply Gaussian tuning curve
276    freq_bins = (-0.5 * ((freq_bins / n_chroma_float - ctroct) / octwidth).powi(2)).exp();
277
278    wts *= &freq_bins;
279
280    // np.roll(), np bro
281    let mut b = Array2::zeros(wts.dim());
282    b.slice_mut(s![-3.., ..]).assign(&wts.slice(s![..3, ..]));
283    b.slice_mut(s![..-3, ..]).assign(&wts.slice(s![3.., ..]));
284
285    wts = b;
286    let non_aliased = 1 + n_fft / 2;
287    Ok(wts.slice_move(s![.., ..non_aliased]))
288}
289
290#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
291#[allow(clippy::missing_inline_in_public_items)]
292pub fn pip_track(
293    sample_rate: u32,
294    spectrum: &Array2<f32>,
295    n_fft: usize,
296) -> AnalysisResult<(Vec<f32>, Vec<f32>)> {
297    #[allow(clippy::cast_precision_loss)]
298    let sample_rate_float = sample_rate as f32;
299    let fmin = 150.0;
300    let fmax = 4000.0.min(sample_rate_float / 2.0);
301    let threshold = 0.1;
302
303    let fft_freqs = Array::linspace(0., sample_rate_float / 2., 1 + n_fft / 2);
304
305    let length = spectrum.shape()[1];
306
307    let freq_mask = fft_freqs
308        .iter()
309        .map(|&f| (fmin <= f) && (f < fmax))
310        .collect::<BitVec>();
311
312    let ref_value = spectrum.map_axis(Axis(0), |x| {
313        let first = *x.first().expect("empty spectrum axis");
314        let max = x.fold(first, |acc, &elem| acc.max(elem));
315        threshold * max
316    });
317
318    // compute number of taken columns and beginning / end indices
319    let freq_mask_len = freq_mask.len();
320    let (taken_columns, beginning, end) = freq_mask.iter().enumerate().fold(
321        (0, freq_mask_len, 0),
322        |(taken, beginning, end), (i, b)| {
323            b.then(|| (taken + 1, beginning.min(i), end.max(i + 1)))
324                .unwrap_or((taken, beginning, end))
325        },
326    );
327
328    // Validate that a valid frequency range was found
329    if beginning >= end {
330        return Err(AnalysisError::AnalysisError(String::from(
331            "in chroma: no valid frequency range found",
332        )));
333    }
334    // There will be at most taken_columns * length elements in pitches / mags
335    let mut pitches = Vec::with_capacity(taken_columns * length);
336    let mut mags = Vec::with_capacity(taken_columns * length);
337
338    let zipped = Zip::indexed(spectrum.slice(s![beginning..end - 3, ..]))
339        .and(spectrum.slice(s![beginning + 1..end - 2, ..]))
340        .and(spectrum.slice(s![beginning + 2..end - 1, ..]));
341
342    // No need to handle the last column, since freq_mask[length - 1] is
343    // always going to be `false` for 22.5kHz
344    zipped.for_each(|(i, j), &before_elem, &elem, &after_elem| {
345        if elem > ref_value[j] && after_elem <= elem && before_elem < elem {
346            let avg = 0.5 * (after_elem - before_elem);
347            let mut shift = (2.0).mul_add(elem, -after_elem - before_elem);
348            if shift.abs() < f32::MIN_POSITIVE {
349                shift += 1.;
350            }
351            shift = avg / shift;
352            #[allow(clippy::cast_precision_loss)]
353            pitches.push(((i + beginning + 1) as f32 + shift) * sample_rate_float / n_fft as f32);
354            mags.push((0.5 * avg).mul_add(shift, elem));
355        }
356    });
357
358    Ok((pitches, mags))
359}
360
361// Only use this with strictly positive `frequencies`.
362#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
363#[inline]
364pub fn pitch_tuning(
365    mut frequencies: Array1<f32>,
366    resolution: f32,
367    bins_per_octave: u32,
368) -> AnalysisResult<f32> {
369    if unlikely(frequencies.is_empty()) {
370        return Ok(0.0);
371    }
372    hz_to_octs_inplace(&mut frequencies, 0.0, 12);
373    #[allow(clippy::cast_precision_loss)]
374    frequencies.mapv_inplace(|x| (bins_per_octave as f32 * x).fract());
375
376    // Wrap values from [0,1) to [-0.5, 0.5), then shift back up to [0,1)
377    frequencies.mapv_inplace(|x| if x >= 0.5 { x - 0.5 } else { x + 0.5 });
378
379    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
380    let indexes = (frequencies / resolution).mapv(|x| x as usize);
381    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
382    let mut counts: Array1<usize> = Array::zeros(resolution.recip() as usize);
383    for &idx in &indexes {
384        counts[idx] += 1;
385    }
386    let max_index = counts
387        .argmax()
388        .map_err_unlikely(|e| AnalysisError::AnalysisError(format!("in chroma: {e}")))?;
389
390    // Return the bin with the most reoccurring frequency.
391    #[allow(clippy::cast_precision_loss)]
392    Ok((100. * resolution).mul_add(max_index as f32, -50.) / 100.)
393}
394
395#[allow(clippy::missing_errors_doc, clippy::missing_panics_doc)]
396#[inline]
397pub fn estimate_tuning(
398    sample_rate: u32,
399    spectrum: &Array2<f32>,
400    n_fft: usize,
401    resolution: f32,
402    bins_per_octave: u32,
403) -> AnalysisResult<f32> {
404    let (pitch, mag) = pip_track(sample_rate, spectrum, n_fft)?;
405
406    if unlikely(pitch.is_empty()) {
407        return Ok(0.);
408    }
409
410    let (filtered_pitch, filtered_mag): (Vec<N32>, Vec<N32>) = pitch
411        .iter()
412        .zip(&mag)
413        .filter(|&(&p, _)| p > 0.)
414        .map(|(x, y)| (n32(*x), n32(*y)))
415        .unzip();
416
417    let mut mag_copy = filtered_mag.clone();
418    let mid = mag_copy.len() / 2;
419    let threshold = *mag_copy
420        .select_nth_unstable_by(mid, |a, b| a.partial_cmp(b).unwrap())
421        .1;
422
423    let pitch = filtered_pitch
424        .iter()
425        .zip(&filtered_mag)
426        .filter_map(
427            |(&p, &m)| {
428                if m >= threshold { Some(p.into()) } else { None }
429            },
430        )
431        .collect::<Array1<f32>>();
432    pitch_tuning(pitch, resolution, bins_per_octave)
433}
434
435#[allow(
436    clippy::missing_errors_doc,
437    clippy::missing_panics_doc,
438    clippy::module_name_repetitions
439)]
440#[inline]
441pub fn chroma_stft(
442    sample_rate: u32,
443    spectrum: &Array2<f32>, // shape: (window_length / 2 + 1, signal.len().div_ceil(hop_length))
444    n_fft: usize,
445    n_chroma: u32,
446    tuning: f32,
447) -> AnalysisResult<Array2<f32>> {
448    let mut raw_chroma = chroma_filter(sample_rate, n_fft, n_chroma, tuning)?;
449
450    raw_chroma = raw_chroma.dot(&spectrum.pow2());
451
452    // We want to maximize cache locality, and are iterating over columns,
453    // so let's make sure our array is in column-major order.
454    raw_chroma = raw_chroma
455        .to_shape((raw_chroma.dim(), Order::ColumnMajor))
456        .map_err_unlikely(|_| {
457            AnalysisError::AnalysisError(String::from("in chroma: failed to reorder array"))
458        })?
459        .to_owned();
460
461    Zip::from(raw_chroma.columns_mut()).for_each(|mut row| {
462        let sum = row.sum(); // we know that our values are positive, so no need to use abs
463        if sum >= f32::MIN_POSITIVE {
464            row /= sum;
465        }
466    });
467
468    Ok(raw_chroma)
469}
470
#[cfg(test)]
mod test {
    use super::*;
    use crate::{
        SAMPLE_RATE,
        decoder::{Decoder as _, MecompDecoder as Decoder},
        utils::stft,
    };
    use ndarray::{Array2, arr1, arr2};
    use ndarray_npy::ReadNpyExt as _;
    use std::{fs::File, path::Path};

    #[test]
    fn test_chroma_interval_features() {
        let file = File::open("data/chroma.npy").unwrap();
        let chroma = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);
        let features = chroma_interval_features(&chroma);
        let expected_features: Array1<f32> = arr1(&[
            0.038_602_84,
            0.021_852_81,
            0.042_243_79,
            0.063_852_78,
            0.073_111_48,
            0.025_125_66,
            0.003_198_99,
            0.003_113_08,
            0.001_074_33,
            0.002_418_61,
        ]);
        assert_eq!(features.len(), expected_features.len());
        for (expected, actual) in expected_features.iter().zip(&features) {
            // The original (f64) test wanted an absolute error below 1e-8; this port
            // works in f32, so compare with a (two-sided) relative tolerance.
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.001,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    #[test]
    fn test_extract_interval_features() {
        let file = File::open("data/chroma-interval.npy").unwrap();
        let chroma = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);
        let templates = arr2(&[
            [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
            [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
            [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
            [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]);

        let file = File::open("data/interval-feature-matrix.npy").unwrap();
        let expected_interval_features = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let interval_features = extract_interval_features(&chroma, &templates);
        for (expected, actual) in expected_interval_features
            .iter()
            .zip(interval_features.iter())
        {
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[test]
    fn test_normalize_feature_sequence() {
        let array = arr2(&[[0.1, 0.3, 0.4], [1.1, 0.53, 1.01]]);
        let expected_array: Array2<f32> = arr2(&[
            [0.083_333_33, 0.361_445_78, 0.283_687_94],
            [0.916_666_67, 0.638_554_22, 0.716_312_06],
        ]);

        let normalized_array = normalize_feature_sequence(&array);

        assert!(!array.is_empty() && !expected_array.is_empty());

        for (expected, actual) in expected_array.iter().zip(normalized_array.iter()) {
            assert!(
                0.000_000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[test]
    fn test_chroma_desc() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let expected_values = [
            -0.342_925_13,
            -0.628_034_23,
            -0.280_950_96,
            0.086_864_59,
            0.244_460_82,
            -0.572_325_7,
            0.232_920_65,
            0.199_811_46,
            -0.585_944_06,
            -0.067_842_96,
        ];
        for (expected, actual) in expected_values.iter().zip(chroma_desc.get_value().iter()) {
            // original test wanted absolute error < 0.0000001
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.01,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    #[test]
    fn test_chroma_stft_decode() {
        let signal = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap()
            .samples;
        let stft = stft(&signal, 8192, 2205);

        let file = File::open("data/chroma.npy").unwrap();
        let expected_chroma = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let chroma = chroma_stft(22050, &stft, 8192, 12, -0.049_999_999_999_999_99).unwrap();

        assert!(!chroma.is_empty() && !expected_chroma.is_empty());

        for (expected, actual) in expected_chroma.iter().zip(chroma.iter()) {
            // original test wanted absolute error < 0.0000001
            let relative_error = (expected - actual).abs() / expected.abs();
            assert!(
                relative_error < 0.01,
                "relative error: {relative_error}, expected: {expected}, actual: {actual}"
            );
        }
    }

    #[test]
    fn test_estimate_tuning() {
        let file = File::open("data/spectrum-chroma.npy").unwrap();
        let arr = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let tuning = estimate_tuning(22050, &arr, 2048, 0.01, 12).unwrap();
        assert!(
            0.000_001 > (-0.099_999_999_999_999_98 - tuning).abs(),
            "{tuning} !~= -0.09999999999999998"
        );
    }

    #[test]
    fn test_chroma_estimate_tuning_empty_fix() {
        assert!(0. == estimate_tuning(22050, &Array2::zeros((8192, 1)), 8192, 0.01, 12).unwrap());
    }

    #[test]
    fn test_estimate_tuning_decode() {
        let signal = Decoder::new()
            .unwrap()
            .decode(Path::new("data/s16_mono_22_5kHz.flac"))
            .unwrap()
            .samples;
        let stft = stft(&signal, 8192, 2205);

        let tuning = estimate_tuning(22050, &stft, 8192, 0.01, 12).unwrap();
        assert!(
            0.000_001 > (-0.049_999_999_999_999_99 - tuning).abs(),
            "{tuning} !~= -0.04999999999999999"
        );
    }

    #[test]
    fn test_pitch_tuning() {
        let file = File::open("data/pitch-tuning.npy").unwrap();
        let pitch = Array1::<f64>::read_npy(file).unwrap();
        #[allow(clippy::cast_possible_truncation)]
        let pitch = pitch.mapv(|x| x as f32);

        let tuned = pitch_tuning(pitch, 0.05, 12).unwrap();
        assert!(f32::EPSILON > (tuned + 0.1).abs(), "{tuned} != -0.1");
    }

    #[test]
    fn test_pitch_tuning_no_frequencies() {
        let frequencies = arr1(&[]);
        let tuned = pitch_tuning(frequencies, 0.05, 12).unwrap();
        assert!(f32::EPSILON > tuned.abs(), "{tuned} != 0");
    }

    #[test]
    fn test_pip_track() {
        let file = File::open("data/spectrum-chroma.npy").unwrap();
        let spectrum = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let mags_file = File::open("data/spectrum-chroma-mags.npy").unwrap();
        let expected_mags = Array1::<f64>::read_npy(mags_file)
            .unwrap()
            .mapv(|x| x as f32);

        let pitches_file = File::open("data/spectrum-chroma-pitches.npy").unwrap();
        let expected_pitches = Array1::<f64>::read_npy(pitches_file)
            .unwrap()
            .mapv(|x| x as f32);

        let (mut pitches, mut mags) = pip_track(22050, &spectrum, 2048).unwrap();
        pitches.sort_by(|a, b| a.partial_cmp(b).unwrap());
        mags.sort_by(|a, b| a.partial_cmp(b).unwrap());

        for (expected_pitch, actual_pitch) in expected_pitches.iter().zip(pitches.iter()) {
            // original test wanted 000_000_01
            assert!(
                0.001 > (expected_pitch - actual_pitch).abs(),
                "{expected_pitch} !~= {actual_pitch}"
            );
        }
        for (expected_mag, actual_mag) in expected_mags.iter().zip(mags.iter()) {
            // original test wanted 000_000_01
            assert!(
                0.001 > (expected_mag - actual_mag).abs(),
                "{expected_mag} !~= {actual_mag}"
            );
        }
    }

    #[test]
    fn test_chroma_filter() {
        let file = File::open("data/chroma-filter.npy").unwrap();
        let expected_filter = Array2::<f64>::read_npy(file).unwrap().mapv(|x| x as f32);

        let filter = chroma_filter(22050, 2048, 12, -0.1).unwrap();

        assert!(filter.iter().all(|&x| x > 0.));

        for (expected, actual) in expected_filter.iter().zip(filter.iter()) {
            // original test wanted 0.000_000_001
            assert!(
                0.000_1 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[rstest::rstest]
    // High 6 should be a major triad, 7 minor, 8 diminished and 9 augmented.
    #[case::major_triad("data/chroma/Cmaj.ogg", 6)]
    #[case::major_triad("data/chroma/Dmaj.ogg", 6)]
    #[case::minor_triad("data/chroma/Cmin.ogg", 7)]
    #[case::diminished_triad("data/chroma/Cdim.ogg", 8)]
    #[case::augmented_triad("data/chroma/Caug.ogg", 9)]
    fn test_end_result_triads(
        #[case] path: &str,
        #[case] expected_dominant_chroma_feature_index: usize,
    ) {
        let song = Decoder::new().unwrap().decode(Path::new(path)).unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();

        let mut indices: Vec<usize> = (0..chroma_values.len()).collect();
        indices.sort_by(|&i, &j| chroma_values[j].partial_cmp(&chroma_values[i]).unwrap());
        assert_eq!(indices[0], expected_dominant_chroma_feature_index);
        for (i, v) in chroma_values.into_iter().enumerate() {
            if (6..=10).contains(&i) {
                if i == expected_dominant_chroma_feature_index {
                    assert!(v > 0.8, "feature {i}: {v} should be > 0.8");
                } else {
                    assert!(v < 0.0, "feature {i}: {v} should be < 0");
                }
            }
        }
    }

    #[test]
    fn test_end_l2_norm_dyad() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/dyad_tritone_IC6.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[10] > 0.9);
    }

    #[test]
    fn test_end_l2_norm_mode() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/Cmaj_triads.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[11] > 0.9);
    }

    #[test]
    fn test_end_l2_norm_ratio() {
        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/chroma/triad_aug_maximize_ratio.ogg"))
            .unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();
        assert!(chroma_values[12] > 0.7);
    }

    #[rstest::rstest]
    // Test all 12 intervals.
    #[case::minor_second("data/chroma/minor_second.ogg", 0)]
    #[case::major_second("data/chroma/major_second.ogg", 1)]
    #[case::minor_third("data/chroma/minor_third.ogg", 2)]
    #[case::major_third("data/chroma/major_third.ogg", 3)]
    #[case::perfect_fourth("data/chroma/perfect_fourth.ogg", 4)]
    #[case::tritone("data/chroma/tritone.ogg", 5)]
    #[case::perfect_fifth("data/chroma/perfect_fifth.ogg", 4)]
    #[case::minor_sixth("data/chroma/minor_sixth.ogg", 3)]
    #[case::major_sixth("data/chroma/major_sixth.ogg", 2)]
    #[case::minor_seventh("data/chroma/minor_seventh.ogg", 1)]
    #[case::major_seventh("data/chroma/major_seventh.ogg", 0)]
    fn test_end_result_intervals(
        #[case] path: &str,
        #[case] expected_dominant_chroma_feature_index: usize,
    ) {
        let song = Decoder::new().unwrap().decode(Path::new(path)).unwrap();
        let mut chroma_desc = ChromaDesc::new(SAMPLE_RATE, 12);
        chroma_desc.do_(&song.samples).unwrap();
        let chroma_values = chroma_desc.get_value();

        let mut indices: Vec<usize> = (0..chroma_values.len()).collect();
        indices.sort_by(|&i, &j| chroma_values[j].partial_cmp(&chroma_values[i]).unwrap());
        assert_eq!(indices[0], expected_dominant_chroma_feature_index);
        for (i, v) in chroma_values.into_iter().enumerate() {
            if i < 6 {
                if i == expected_dominant_chroma_feature_index {
                    assert!(v > 0.9, "feature {i}: {v} should be > 0.9");
                } else {
                    assert!(v < 0.0, "feature {i}: {v} should be < 0");
                }
            }
        }
    }
}