Skip to main content

math_audio_dsp/audio_features/
utils.rs

1//! Utility functions for audio feature extraction.
2//!
3//! Ported from bliss-audio utils.rs — pure Rust implementations of
4//! mean, geometric mean, normalization, zero-crossing, STFT, etc.
5
6use ndarray::{Array, Array1, Array2, arr1, s};
7use rustfft::FftPlanner;
8use rustfft::num_complex::Complex;
9
10use crate::analysis::plan_fft_forward;
11use crate::stft::generate_hann_window;
12
/// Linearly rescale `value` from `[min_value, max_value]` onto `[-1, 1]`.
///
/// `min_value` maps to `-1.0` and `max_value` maps to `1.0`. The result is not
/// clamped, so out-of-range inputs land outside `[-1, 1]`. A degenerate range
/// (`min_value == max_value`) produces a non-finite value.
pub fn normalize(value: f32, min_value: f32, max_value: f32) -> f32 {
    let doubled_offset = 2. * (value - min_value);
    let span = max_value - min_value;
    doubled_offset / span - 1.
}
17
/// Arithmetic mean of `input`.
///
/// An empty slice yields `0.0` rather than `NaN`.
pub fn mean(input: &[f32]) -> f32 {
    match input.len() {
        0 => 0.0,
        count => input.iter().sum::<f32>() / count as f32,
    }
}
25
/// Geometric mean of `input`, computed without a `log` call per sample.
///
/// Optimized algorithm from bliss (courtesy of Jacques-Henri Jourdan). Samples are
/// multiplied eight at a time in `f64`, scaled by 2^500 to stay clear of denormals.
/// The binary exponent and mantissa of each product are then accumulated separately.
/// The running mantissa is renormalised after every step, so arbitrarily long inputs
/// neither overflow nor underflow. Trailing samples that do not fill a group of eight
/// are folded in with an explicit `log2`, so any input length is supported.
///
/// Values are expected to lie in `[0, 2^65]`. Larger values can overflow an
/// eight-sample product, and negative or NaN values give meaningless results.
///
/// Returns `0.0` if `input` is empty or contains a zero.
pub fn geometric_mean(input: &[f32]) -> f32 {
    // 2^500: keeps each eight-sample product clear of underflow and denormals.
    const SCALE: f64 = 3.273390607896142e150;
    const SCALE_EXPONENT: i64 = 500;
    const EXPONENT_BIAS: i64 = 1023;
    const MANTISSA_MASK: u64 = 0x000F_FFFF_FFFF_FFFF;
    // Biased exponent bits of 1.0; OR-ed onto a mantissa they yield a value in [1, 2).
    const ONE_EXPONENT_BITS: u64 = 0x3FF0_0000_0000_0000;

    // Split a positive, normal `x` into (unbiased exponent, mantissa in [1, 2)).
    let split = |x: f64| -> (i64, f64) {
        let bits = x.to_bits();
        let exponent = (bits >> 52) as i64 - EXPONENT_BIAS;
        let frac = f64::from_bits((bits & MANTISSA_MASK) | ONE_EXPONENT_BITS);
        (exponent, frac)
    };

    if input.is_empty() {
        return 0.0;
    }

    let chunks = input.chunks_exact(8);
    let tail = chunks.remainder();

    // Invariant: log2(product of all processed samples) == exponent_sum + log2(mantissa).
    let mut exponent_sum: i64 = 0;
    let mut mantissa: f64 = 1.0;
    for ch in chunks {
        let mut m = (f64::from(ch[0]) * f64::from(ch[1])) * (f64::from(ch[2]) * f64::from(ch[3]));
        m *= SCALE;
        m *= (f64::from(ch[4]) * f64::from(ch[5])) * (f64::from(ch[6]) * f64::from(ch[7]));
        if m == 0.0 {
            return 0.0;
        }
        let (chunk_exponent, chunk_mantissa) = split(m);
        exponent_sum += chunk_exponent - SCALE_EXPONENT;
        // The product of two mantissas lies in [1, 4). Move its exponent into the
        // integer sum so the running mantissa stays in [1, 2) however long the input is.
        let (carry, renormalised) = split(mantissa * chunk_mantissa);
        exponent_sum += carry;
        mantissa = renormalised;
    }

    // At most seven leftover samples: a direct log2 each is cheap.
    let mut tail_log2 = 0.0_f64;
    for &x in tail {
        if x == 0.0 {
            return 0.0;
        }
        tail_log2 += f64::from(x).log2();
    }

    let total_log2 = exponent_sum as f64 + mantissa.log2() + tail_log2;
    (total_log2 / input.len() as f64).exp2() as f32
}
45
/// Count the sign changes in `input` (Essentia's zero-crossing algorithm).
///
/// A sample is "positive" only when it is strictly greater than zero. Moving
/// between a positive sample and a zero or negative one therefore counts as a
/// crossing, while zero ↔ negative does not. Slices with fewer than two samples
/// have no crossings.
pub fn number_crossings(input: &[f32]) -> u32 {
    let crossings = input
        .windows(2)
        .filter(|pair| (pair[0] > 0.) != (pair[1] > 0.))
        .count();
    crossings as u32
}
63
/// Reflect-pad `array` by `pad` samples on each side (mirror boundary, edge sample
/// not repeated).
///
/// Uses the same convention as NumPy's `np.pad(array, pad, mode="reflect")`. For
/// example, `[1, 2, 3]` padded by `2` becomes `[3, 2, 1, 2, 3, 2, 1]`. Padding wider
/// than the array keeps reflecting back and forth, and a single-element array is
/// simply repeated.
///
/// # Panics
///
/// Panics if `array` is empty and `pad > 0`, since there is nothing to reflect.
pub fn reflect_pad(array: &[f32], pad: usize) -> Vec<f32> {
    let len = array.len();
    if pad == 0 {
        return array.to_vec();
    }
    assert!(
        !array.is_empty(),
        "reflect_pad: cannot reflect-pad an empty array"
    );

    // Mirroring about both end samples (without repeating them) makes the extended
    // signal periodic with period 2 * (len - 1). A single sample has period 0.
    let period = 2 * (len - 1);
    let mirror = |position: isize| -> f32 {
        if period == 0 {
            return array[0];
        }
        let folded = position.rem_euclid(period as isize) as usize;
        let index = if folded < len { folded } else { period - folded };
        array[index]
    };

    let mut output = Vec::with_capacity(len + 2 * pad);
    output.extend((-(pad as isize)..0).map(mirror));
    output.extend_from_slice(array);
    output.extend((len as isize..(len + pad) as isize).map(mirror));
    output
}
78
79/// Short-time Fourier transform with Hann window.
80pub fn stft(signal: &[f32], window_length: usize, hop_length: usize) -> Array2<f64> {
81    let mut stft = Array2::zeros((
82        (signal.len() as f32 / hop_length as f32).ceil() as usize,
83        window_length / 2 + 1,
84    ));
85    let signal = reflect_pad(signal, window_length / 2);
86
87    // Periodic Hann window
88    let hann_window = Array::from_vec(generate_hann_window(window_length));
89
90    let fft = plan_fft_forward(window_length);
91
92    for (window, mut stft_col) in signal
93        .windows(window_length)
94        .step_by(hop_length)
95        .zip(stft.rows_mut())
96    {
97        let mut fft_input = (arr1(window) * &hann_window).mapv(|x| Complex::new(x, 0.));
98        match fft_input.as_slice_mut() {
99            Some(s) => fft.process(s),
100            None => {
101                fft.process(&mut fft_input.to_vec());
102            }
103        };
104        stft_col.assign(
105            &fft_input
106                .slice(s![..window_length / 2 + 1])
107                .mapv(|x| (x.re * x.re + x.im * x.im).sqrt() as f64),
108        );
109    }
110    stft.permuted_axes((1, 0))
111}
112
113/// Convert Hz frequencies to fractional octaves (in-place).
114pub fn hz_to_octs_inplace(
115    frequencies: &mut Array1<f64>,
116    tuning: f64,
117    bins_per_octave: u32,
118) -> &mut Array1<f64> {
119    let a440 = 440.0 * 2_f64.powf(tuning / f64::from(bins_per_octave));
120    *frequencies /= a440 / 16.;
121    frequencies.mapv_inplace(f64::log2);
122    frequencies
123}
124
125/// FFT-based convolution (same-size output).
126pub fn convolve(input: &Array1<f64>, kernel: &Array1<f64>) -> Array1<f64> {
127    let mut common_length = input.len() + kernel.len();
128    if !common_length.is_multiple_of(2) {
129        common_length -= 1;
130    }
131    let mut padded_input = Array::from_elem(
132        common_length,
133        Complex {
134            re: 0.0_f64,
135            im: 0.0,
136        },
137    );
138    padded_input
139        .slice_mut(s![..input.len()])
140        .assign(&input.mapv(|x| Complex::new(x, 0.)));
141    let mut padded_kernel = Array::from_elem(
142        common_length,
143        Complex {
144            re: 0.0_f64,
145            im: 0.0,
146        },
147    );
148    padded_kernel
149        .slice_mut(s![..kernel.len()])
150        .assign(&kernel.mapv(|x| Complex::new(x, 0.)));
151
152    let mut planner = FftPlanner::new();
153    let forward = planner.plan_fft_forward(common_length);
154    forward.process(padded_input.as_slice_mut().unwrap());
155    forward.process(padded_kernel.as_slice_mut().unwrap());
156
157    let mut multiplication = padded_input * padded_kernel;
158
159    let back = planner.plan_fft_inverse(common_length);
160    back.process(multiplication.as_slice_mut().unwrap());
161
162    let multiplication_length = multiplication.len() as f64;
163    let multiplication = multiplication
164        .slice_move(s![
165            (kernel.len() - 1) / 2..(kernel.len() - 1) / 2 + input.len()
166        ])
167        .mapv(|x| x.re);
168    multiplication / multiplication_length
169}
170
/// Population standard deviation of `values` (divisor `N`, not `N - 1`).
///
/// Slices with fewer than two elements have no spread and yield `0.0`.
pub fn std_deviation(values: &[f32]) -> f32 {
    let count = values.len();
    if count < 2 {
        return 0.0;
    }
    let n = count as f32;
    let average = values.iter().sum::<f32>() / n;
    let squared_deviations: f32 = values
        .iter()
        .map(|&v| {
            let delta = v - average;
            delta * delta
        })
        .sum();
    (squared_deviations / n).sqrt()
}
180
#[cfg(test)]
mod tests {
    use super::*;

    /// Fail unless `actual` lies strictly within `tolerance` of `expected`.
    fn assert_close(expected: f32, actual: f32, tolerance: f32) {
        assert!(
            (expected - actual).abs() < tolerance,
            "expected {expected} ± {tolerance}, got {actual}"
        );
    }

    #[test]
    fn test_mean() {
        assert_eq!(mean(&[0.0, 1.0, 2.0, 3.0, 4.0]), 2.0);
    }

    #[test]
    fn test_geometric_mean() {
        // Any zero sample collapses the product, and therefore the mean, to zero.
        assert_eq!(geometric_mean(&[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0]), 0.0);

        // The product is 256 = 2^8 over eight samples, so the geometric mean is 2.
        let samples = [4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0, 2.0];
        assert_close(2.0, geometric_mean(&samples), 1e-4);
    }

    #[test]
    fn test_number_crossings() {
        // Alternating signs: every adjacent pair is a crossing.
        assert_eq!(number_crossings(&[-1.0, 1.0, -1.0, 1.0]), 3);
        // Constant sign: no crossings at all.
        assert_eq!(number_crossings(&[1.0, 1.0, 1.0]), 0);
    }

    #[test]
    fn test_normalize() {
        for (value, expected) in [(0.5, 0.0), (0.0, -1.0), (1.0, 1.0)] {
            assert_close(expected, normalize(value, 0.0, 1.0), 1e-6);
        }
    }

    #[test]
    fn test_std_deviation() {
        let values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert_close(2.0, std_deviation(&values), 0.01);
    }

    #[test]
    fn test_reflect_pad() {
        let ramp: Vec<f32> = (0..100u16).map(f32::from).collect();
        let padded = reflect_pad(&ramp, 3);
        // Leading pad mirrors samples 3, 2, 1 around sample 0, which is not repeated.
        assert_eq!(&padded[..4], &[3.0, 2.0, 1.0, 0.0]);
        assert_eq!(&padded[3..103], &ramp[..]);
        // Trailing pad mirrors samples 98, 97, 96 around sample 99.
        assert_eq!(&padded[103..106], &[98.0, 97.0, 96.0]);
    }
}