Skip to main content

math_audio_dsp/audio_features/
chroma.rs

1//! Chroma feature extraction.
2//!
3//! Ported from bliss-audio chroma.rs — already pure Rust, no aubio dependency.
4//! Computes 13 interval/triad features from the chromagram.
5
6use super::utils::{hz_to_octs_inplace, normalize, stft};
7use ndarray::{Array, Array1, Array2, Axis, Zip, arr1, arr2, s};
8use oxiblas_ndarray::blas::{dot_view, matmul};
9
/// Frame length in samples (2^13 = 8192) used both for the STFT and as `n_fft`
/// when building the chroma filter bank.
const WINDOW_SIZE: usize = 1 << 13;
/// Upper bound handed to `normalize` for the ten template features.
const MAX_VALUE: f32 = 1.0;
/// Lower bound handed to `normalize` for the ten template features.
const MIN_VALUE: f32 = 0.0;
/// L2 norm of the six interval-class features that is mapped to +1 when the norm is
/// rescaled to `[-1, 1]` (larger norms are clamped to 1).
const MAX_L2_INTERVAL: f32 = 0.25;
/// L2 norm of the four triad features that is mapped to +1 when the norm is rescaled
/// to `[-1, 1]` (larger norms are clamped to 1).
const MAX_L2_TRIAD: f32 = 0.025;
/// Largest possible triad/interval angle: `atan2` of two non-negative norms is at most π/2.
const MAX_TRIAD_INTERVAL_RATIO: f32 = std::f32::consts::FRAC_PI_2;
16
/// Failure raised while extracting chroma features.
///
/// Carries a human-readable description of the problem; it is rendered with a
/// `"chroma error: "` prefix by its `Display` implementation.
#[derive(Debug, Clone)]
pub struct ChromaError(pub String);

impl std::error::Error for ChromaError {}

impl std::fmt::Display for ChromaError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("chroma error: ")?;
        formatter.write_str(&self.0)
    }
}
28
29/// Compute 13 chroma interval features from the full song samples.
30///
31/// Returns a Vec of 13 normalized features (6 interval classes + 4 triads + 2 L2 norms + 1 ratio).
32pub fn compute_chroma_features(samples: &[f32], sample_rate: u32) -> Result<Vec<f32>, ChromaError> {
33    let n_chroma = 12u32;
34
35    let mut spectrum = stft(samples, WINDOW_SIZE, 2205);
36    let tuning = estimate_tuning(sample_rate, &spectrum, WINDOW_SIZE, 0.01, 12)?;
37    let chroma = chroma_stft(sample_rate, &mut spectrum, WINDOW_SIZE, n_chroma, tuning)?;
38
39    let mut raw_features = chroma_interval_features(&chroma)?;
40
41    let (mut interval_class, mut interval_class_mode) =
42        raw_features.view_mut().split_at(Axis(0), 6);
43
44    let l2_norm_interval_class = dot_view(&interval_class.view(), &interval_class.view()).sqrt();
45    let l2_norm_interval_class_mode =
46        dot_view(&interval_class_mode.view(), &interval_class_mode.view()).sqrt();
47
48    if l2_norm_interval_class > 0. {
49        interval_class /= l2_norm_interval_class;
50    }
51    if l2_norm_interval_class_mode > 0. {
52        interval_class_mode /= l2_norm_interval_class_mode;
53    }
54
55    let mut features: Vec<f32> = raw_features
56        .mapv_into_any(|x| normalize(x as f32, MIN_VALUE, MAX_VALUE))
57        .to_vec();
58
59    let normalized_l2_norm_interval_class =
60        (2. * (l2_norm_interval_class as f32 - 0.) / (MAX_L2_INTERVAL - 0.) - 1.).min(1.);
61    features.push(normalized_l2_norm_interval_class);
62
63    let normalized_l2_norm_interval_class_mode =
64        (2. * (l2_norm_interval_class_mode as f32 - 0.) / (MAX_L2_TRIAD - 0.) - 1.).min(1.);
65    features.push(normalized_l2_norm_interval_class_mode);
66
67    let angle = (20. * l2_norm_interval_class_mode).atan2(l2_norm_interval_class + 1e-12_f64);
68    let normalized_ratio = 2. * (angle as f32 - 0.) / (MAX_TRIAD_INTERVAL_RATIO - 0.) - 1.;
69    features.push(normalized_ratio);
70
71    Ok(features)
72}
73
74fn chroma_interval_features(chroma: &Array2<f64>) -> Result<Array1<f64>, ChromaError> {
75    let chroma = normalize_feature_sequence(&chroma.mapv(|x| (x * 15.).exp()));
76    let templates = arr2(&[
77        [1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
78        [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
79        [0, 1, 0, 0, 0, 0, 0, 0, 0, 0],
80        [0, 0, 1, 0, 0, 0, 0, 1, 1, 0],
81        [0, 0, 0, 1, 0, 0, 1, 0, 0, 1],
82        [0, 0, 0, 0, 1, 0, 0, 0, 0, 0],
83        [0, 0, 0, 0, 0, 1, 0, 0, 1, 0],
84        [0, 0, 0, 0, 0, 0, 1, 1, 0, 0],
85        [0, 0, 0, 0, 0, 0, 0, 0, 0, 1],
86        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
87        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
88        [0, 0, 0, 0, 0, 0, 0, 0, 0, 0],
89    ]);
90    let interval_feature_matrix = extract_interval_features(&chroma, &templates);
91    interval_feature_matrix.mean_axis(Axis(1)).ok_or_else(|| {
92        ChromaError("Tried to run chroma on empty array. Need at least one sample.".to_string())
93    })
94}
95
96fn extract_interval_features(chroma: &Array2<f64>, templates: &Array2<i32>) -> Array2<f64> {
97    let mut f_intervals: Array2<f64> = Array::zeros((chroma.shape()[1], templates.shape()[1]));
98    for (template, mut f_interval) in templates
99        .axis_iter(Axis(1))
100        .zip(f_intervals.axis_iter_mut(Axis(1)))
101    {
102        for shift in 0..12 {
103            let mut vec: Vec<i32> = template.to_vec();
104            vec.rotate_right(shift);
105            let rolled = arr1(&vec);
106            let power = Zip::from(chroma.t())
107                .and_broadcast(&rolled)
108                .map_collect(|&f, &s| f.powi(s))
109                .map_axis_mut(Axis(1), |x| x.product());
110            f_interval += &power;
111        }
112    }
113    f_intervals.t().to_owned()
114}
115
116fn normalize_feature_sequence(feature: &Array2<f64>) -> Array2<f64> {
117    let mut normalized_sequence = feature.to_owned();
118    for mut column in normalized_sequence.columns_mut() {
119        let mut sum = column.mapv(|x| x.abs()).sum();
120        if sum < 0.0001 {
121            sum = 1.;
122        }
123        column /= sum;
124    }
125    normalized_sequence
126}
127
128fn chroma_filter(
129    sample_rate: u32,
130    n_fft: usize,
131    n_chroma: u32,
132    tuning: f64,
133) -> Result<Array2<f64>, ChromaError> {
134    let ctroct = 5.0;
135    let octwidth = 2.;
136    let n_chroma_float = f64::from(n_chroma);
137    let n_chroma2 = (n_chroma_float / 2.0).round() as u32;
138    let n_chroma2_float = f64::from(n_chroma2);
139
140    let frequencies = Array::linspace(0., f64::from(sample_rate), n_fft + 1);
141
142    let mut freq_bins = frequencies;
143    hz_to_octs_inplace(&mut freq_bins, tuning, n_chroma);
144    freq_bins.mapv_inplace(|x| x * n_chroma_float);
145    freq_bins[0] = freq_bins[1] - 1.5 * n_chroma_float;
146
147    let mut binwidth_bins = Array::ones(freq_bins.raw_dim());
148    binwidth_bins.slice_mut(s![0..freq_bins.len() - 1]).assign(
149        &(&freq_bins.slice(s![1..]) - &freq_bins.slice(s![..-1]))
150            .mapv(|x| if x <= 1. { 1. } else { x }),
151    );
152
153    let mut d: Array2<f64> = Array::zeros((n_chroma as usize, freq_bins.len()));
154    for (idx, mut row) in d.rows_mut().into_iter().enumerate() {
155        row.fill(idx as f64);
156    }
157    d = -d + &freq_bins;
158
159    d.mapv_inplace(|x| {
160        (x + n_chroma2_float + 10. * n_chroma_float) % n_chroma_float - n_chroma2_float
161    });
162    d = d / binwidth_bins;
163    d.mapv_inplace(|x| (-0.5 * (2. * x) * (2. * x)).exp());
164
165    let mut wts = d;
166    for mut col in wts.columns_mut() {
167        let mut sum = col.mapv(|x| x * x).sum().sqrt();
168        if sum < f64::MIN_POSITIVE {
169            sum = 1.;
170        }
171        col /= sum;
172    }
173
174    freq_bins.mapv_inplace(|x| (-0.5 * ((x / n_chroma_float - ctroct) / octwidth).powi(2)).exp());
175    wts *= &freq_bins;
176
177    // np.roll by -3
178    let mut b = Array2::zeros(wts.dim());
179    b.slice_mut(s![-3.., ..]).assign(&wts.slice(s![..3, ..]));
180    b.slice_mut(s![..-3, ..]).assign(&wts.slice(s![3.., ..]));
181
182    wts = b;
183    let non_aliased = 1 + n_fft / 2;
184    Ok(wts.slice_move(s![.., ..non_aliased]))
185}
186
187fn pip_track(
188    sample_rate: u32,
189    spectrum: &Array2<f64>,
190    n_fft: usize,
191) -> Result<(Vec<f64>, Vec<f64>), ChromaError> {
192    let sample_rate_float = f64::from(sample_rate);
193    let fmin = 150.0_f64;
194    let fmax = 4000.0_f64.min(sample_rate_float / 2.0);
195    let threshold = 0.1;
196
197    let fft_freqs = Array::linspace(0., sample_rate_float / 2., 1 + n_fft / 2);
198
199    let length = spectrum.len_of(Axis(0));
200
201    let freq_mask: Vec<bool> = fft_freqs
202        .iter()
203        .map(|&f| (fmin <= f) && (f < fmax))
204        .collect();
205
206    let ref_value = spectrum.map_axis(Axis(0), |x| {
207        let first: f64 = *x.first().expect("empty spectrum axis");
208        x.fold(first, |acc, &elem| if acc > elem { acc } else { elem }) * threshold
209    });
210
211    let taken_columns = freq_mask.iter().filter(|&&x| x).count();
212    let mut pitches = Vec::with_capacity(taken_columns * length);
213    let mut mags = Vec::with_capacity(taken_columns * length);
214
215    let beginning = freq_mask
216        .iter()
217        .position(|&b| b)
218        .ok_or_else(|| ChromaError("in pip_track: no freq mask".to_string()))?;
219    let end = freq_mask
220        .iter()
221        .rposition(|&b| b)
222        .ok_or_else(|| ChromaError("in pip_track: no freq mask".to_string()))?;
223
224    let zipped = Zip::indexed(spectrum.slice(s![beginning..end - 3, ..]))
225        .and(spectrum.slice(s![beginning + 1..end - 2, ..]))
226        .and(spectrum.slice(s![beginning + 2..end - 1, ..]));
227
228    zipped.for_each(|(i, j), &before_elem, &elem, &after_elem| {
229        if elem > ref_value[j] && after_elem <= elem && before_elem < elem {
230            let avg = 0.5 * (after_elem - before_elem);
231            let mut shift = 2. * elem - after_elem - before_elem;
232            if shift.abs() < f64::MIN_POSITIVE {
233                shift += 1.;
234            }
235            shift = avg / shift;
236            pitches.push(((i + beginning + 1) as f64 + shift) * sample_rate_float / n_fft as f64);
237            mags.push(elem + 0.5 * avg * shift);
238        }
239    });
240
241    Ok((pitches, mags))
242}
243
244fn pitch_tuning(
245    frequencies: &mut Array1<f64>,
246    resolution: f64,
247    bins_per_octave: u32,
248) -> Result<f64, ChromaError> {
249    if frequencies.is_empty() {
250        return Ok(0.0);
251    }
252    hz_to_octs_inplace(frequencies, 0.0, 12);
253    frequencies.mapv_inplace(|x| f64::from(bins_per_octave) * x % 1.0);
254    frequencies.mapv_inplace(|x| if x >= 0.5 { x - 1. } else { x });
255
256    let indexes = ((frequencies.to_owned() - -0.5) / resolution).mapv(|x| x as usize);
257    let mut counts: Array1<usize> = Array::zeros(((0.5 - -0.5) / resolution) as usize);
258    for &idx in indexes.iter() {
259        if idx < counts.len() {
260            counts[idx] += 1;
261        }
262    }
263    let max_index = counts
264        .iter()
265        .enumerate()
266        .max_by_key(|&(_, v)| *v)
267        .map(|(i, _)| i)
268        .ok_or_else(|| ChromaError("empty counts in pitch_tuning".to_string()))?;
269
270    Ok((-50. + (100. * resolution * max_index as f64)) / 100.)
271}
272
273fn estimate_tuning(
274    sample_rate: u32,
275    spectrum: &Array2<f64>,
276    n_fft: usize,
277    resolution: f64,
278    bins_per_octave: u32,
279) -> Result<f64, ChromaError> {
280    let (pitch, mag) = pip_track(sample_rate, spectrum, n_fft)?;
281
282    let (filtered_pitch, filtered_mag): (Vec<f64>, Vec<f64>) = pitch
283        .iter()
284        .zip(&mag)
285        .filter(|&(&p, _)| p > 0.)
286        .map(|(x, y)| (*x, *y))
287        .unzip();
288
289    if filtered_pitch.is_empty() {
290        return Ok(0.);
291    }
292
293    // Compute median of magnitudes
294    let mut sorted_mags = filtered_mag.clone();
295    sorted_mags.sort_by(|a, b| a.partial_cmp(b).unwrap());
296    let threshold = if sorted_mags.len() % 2 == 0 {
297        (sorted_mags[sorted_mags.len() / 2 - 1] + sorted_mags[sorted_mags.len() / 2]) / 2.0
298    } else {
299        sorted_mags[sorted_mags.len() / 2]
300    };
301
302    let mut pitch_arr: Array1<f64> = filtered_pitch
303        .iter()
304        .zip(&filtered_mag)
305        .filter_map(|(&p, &m)| if m >= threshold { Some(p) } else { None })
306        .collect::<Vec<f64>>()
307        .into();
308
309    pitch_tuning(&mut pitch_arr, resolution, bins_per_octave)
310}
311
312fn chroma_stft(
313    sample_rate: u32,
314    spectrum: &mut Array2<f64>,
315    n_fft: usize,
316    n_chroma: u32,
317    tuning: f64,
318) -> Result<Array2<f64>, ChromaError> {
319    spectrum.par_mapv_inplace(|x| x * x);
320    let mut raw_chroma = chroma_filter(sample_rate, n_fft, n_chroma, tuning)?;
321
322    raw_chroma = matmul(&raw_chroma, spectrum);
323    for mut row in raw_chroma.columns_mut() {
324        let mut sum = row.mapv(|x| x.abs()).sum();
325        if sum < f64::MIN_POSITIVE {
326            sum = 1.;
327        }
328        row /= sum;
329    }
330    Ok(raw_chroma)
331}
332
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chroma_features_length() {
        // Generate a simple tone
        let sr = 22050u32;
        let duration = 5.0;
        let n = (sr as f32 * duration) as usize;
        let signal: Vec<f32> = (0..n)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin())
            .collect();

        let features = compute_chroma_features(&signal, sr).unwrap();
        assert_eq!(features.len(), 13);
    }

    #[test]
    fn test_normalize_feature_sequence_basic() {
        let array = arr2(&[[0.1, 0.3, 0.4, 0.], [1.1, 0.53, 1.01, 0.]]);
        let expected = arr2(&[
            [0.08333333, 0.36144578, 0.28368794, 0.],
            [0.91666667, 0.63855422, 0.71631206, 0.],
        ]);

        let normalized = normalize_feature_sequence(&array);

        // A shape mismatch would otherwise be hidden by `zip` truncating.
        assert_eq!(normalized.shape(), expected.shape());
        for (actual, expected) in normalized.iter().zip(expected.iter()) {
            assert!(
                (actual - expected).abs() < 1e-6,
                "got {actual}, expected {expected}"
            );
        }
    }

    #[test]
    fn test_chroma_error_display() {
        let err = ChromaError("boom".to_string());
        assert_eq!(err.to_string(), "chroma error: boom");
    }
}