Skip to main content

math_audio_dsp/audio_features/
utils.rs

1//! Utility functions for audio feature extraction.
2//!
3//! Ported from bliss-audio utils.rs — pure Rust implementations of
4//! mean, geometric mean, normalization, zero-crossing, STFT, etc.
5
6use ndarray::{Array, Array1, Array2, arr1, s};
7use rustfft::FftPlanner;
8use rustfft::num_complex::Complex;
9
10use crate::analysis::plan_fft_forward;
11use crate::stft::generate_hann_window;
12
/// Linearly rescale `value` from the interval `[min_value, max_value]` onto `[-1, 1]`.
///
/// `min_value` maps to `-1.0` and `max_value` maps to `1.0`. The result is not
/// clamped, so a `value` outside the interval lands outside `[-1, 1]`. When
/// `min_value == max_value` the division by zero produces a non-finite result.
pub fn normalize(value: f32, min_value: f32, max_value: f32) -> f32 {
    let span = max_value - min_value;
    let offset = value - min_value;
    2. * offset / span - 1.
}
17
/// Arithmetic mean of `input`.
///
/// Returns `0.0` for an empty slice instead of dividing by zero.
pub fn mean(input: &[f32]) -> f32 {
    match input.len() {
        0 => 0.0,
        count => input.iter().sum::<f32>() / count as f32,
    }
}
25
/// Optimized geometric mean (from bliss, courtesy of Jacques-Henri Jourdan).
///
/// Instead of taking one logarithm per sample, the samples are multiplied
/// eight at a time in `f64`; the IEEE-754 exponent field of each partial
/// product is summed as an integer and only the mantissas keep being
/// multiplied, so a logarithm is needed only occasionally.
///
/// Only works for input of size a multiple of 8, with values in `[0, 2^65]`.
/// If any group of eight samples multiplies to zero, `0.0` is returned.
/// An empty slice yields `NaN` (zero divided by zero samples).
///
/// # Panics
///
/// Panics if `input.len()` is not a multiple of 8.
pub fn geometric_mean(input: &[f32]) -> f32 {
    // 2^500: keeps the product of eight tiny samples away from underflow and
    // denormals. Its exponent is subtracted again in the final expression.
    const SCALE: f64 = 3.273390607896142e150;
    const SCALE_LOG2: f64 = 500.;
    // IEEE-754 binary64 layout: exponent bias, mantissa bits, and the bit
    // pattern of the exponent field that encodes 2^0.
    const EXPONENT_BIAS: f64 = 1023.;
    const MANTISSA_MASK: u64 = 0x000F_FFFF_FFFF_FFFF;
    const UNIT_EXPONENT_BITS: u64 = 0x3FF0_0000_0000_0000;
    // Each mantissa is below 2, so folding the running product away once it
    // reaches this threshold guarantees it never overflows `f64`.
    const RENORM_THRESHOLD: f64 = 1e300;

    assert!(
        input.len().is_multiple_of(8),
        "geometric_mean input length must be a multiple of 8, got {}",
        input.len()
    );

    // Sum of raw (biased, scaled) exponent fields; i64 so that very long
    // inputs cannot overflow the accumulator.
    let mut exponents: i64 = 0;
    // Running product of mantissas, each in [1, 2).
    let mut mantissas: f64 = 1.;
    // log2 of the mantissa products that were already folded away.
    let mut folded_log2: f64 = 0.;

    for ch in input.chunks_exact(8) {
        let mut m = (ch[0] as f64 * ch[1] as f64) * (ch[2] as f64 * ch[3] as f64);
        m *= SCALE;
        m *= (ch[4] as f64 * ch[5] as f64) * (ch[6] as f64 * ch[7] as f64);
        if m == 0. {
            return 0.;
        }
        let bits = m.to_bits();
        exponents += (bits >> 52) as i64;
        // Same mantissa bits, exponent forced to 2^0.
        mantissas *= f64::from_bits((bits & MANTISSA_MASK) | UNIT_EXPONENT_BITS);
        if mantissas >= RENORM_THRESHOLD {
            folded_log2 += mantissas.log2();
            mantissas = 1.;
        }
    }

    // Everything stays in f64 until the very end: casting the mantissa
    // product to f32 first would overflow to infinity beyond ~2^128.
    let n = input.len() as f64;
    let log2_product = folded_log2 + mantissas.log2() + exponents as f64;
    // Every chunk of 8 samples contributed one bias and one 2^500 scale.
    (log2_product / n - (EXPONENT_BIAS + SCALE_LOG2) / 8.).exp2() as f32
}
50
/// Count zero-crossings in a signal (Essentia algorithm).
///
/// A sample counts as positive only when it is strictly greater than zero, so
/// zeros and negative values fall on the same side. Each adjacent pair of
/// samples lying on different sides contributes one crossing. Empty and
/// single-sample inputs have no crossings.
pub fn number_crossings(input: &[f32]) -> u32 {
    input
        .windows(2)
        .filter(|pair| (pair[0] > 0.) != (pair[1] > 0.))
        .count() as u32
}
68
/// Map a non-negative position on the infinitely mirrored signal back to an
/// index in `0..=last`. Requires `last > 0`.
fn reflected_index(position: usize, last: usize) -> usize {
    let period = 2 * last;
    let folded = position % period;
    if folded <= last {
        folded
    } else {
        period - folded
    }
}

/// Reflect-pad an array (mirror boundary conditions).
///
/// Returns `array` with `pad` extra samples on each side. The padding mirrors
/// the signal around its first and last samples without repeating them, e.g.
/// `[0, 1, 2, 3]` padded by 2 becomes `[2, 1, 0, 1, 2, 3, 2, 1]`.
///
/// When `pad` exceeds `array.len() - 1` the reflection keeps bouncing between
/// both ends, and a one-sample input is padded by repeating that sample.
/// With `pad == 0` the result is a plain copy of `array`.
///
/// # Panics
///
/// Panics if `array` is empty while `pad > 0`, since there is nothing to mirror.
pub fn reflect_pad(array: &[f32], pad: usize) -> Vec<f32> {
    if pad == 0 {
        return array.to_vec();
    }
    assert!(
        !array.is_empty(),
        "reflect_pad cannot pad an empty array (pad = {})",
        pad
    );

    let last = array.len() - 1;
    if last == 0 {
        // A single sample has no neighbour to mirror; repeat it instead.
        return vec![array[0]; 1 + 2 * pad];
    }

    let mut output = Vec::with_capacity(array.len() + 2 * pad);
    // Position `-d` mirrors to `+d` around the first sample.
    output.extend((1..=pad).rev().map(|d| array[reflected_index(d, last)]));
    output.extend_from_slice(array);
    // Position `last + d` folds back around the last sample.
    output.extend((1..=pad).map(|d| array[reflected_index(last + d, last)]));
    output
}
83
84/// Short-time Fourier transform with Hann window.
85pub fn stft(signal: &[f32], window_length: usize, hop_length: usize) -> Array2<f64> {
86    let mut stft = Array2::zeros((
87        (signal.len() as f32 / hop_length as f32).ceil() as usize,
88        window_length / 2 + 1,
89    ));
90    let signal = reflect_pad(signal, window_length / 2);
91
92    // Periodic Hann window
93    let hann_window = Array::from_vec(generate_hann_window(window_length));
94
95    let fft = plan_fft_forward(window_length);
96
97    for (window, mut stft_col) in signal
98        .windows(window_length)
99        .step_by(hop_length)
100        .zip(stft.rows_mut())
101    {
102        let mut fft_input = (arr1(window) * &hann_window).mapv(|x| Complex::new(x, 0.));
103        match fft_input.as_slice_mut() {
104            Some(s) => fft.process(s),
105            None => {
106                fft.process(&mut fft_input.to_vec());
107            }
108        };
109        stft_col.assign(
110            &fft_input
111                .slice(s![..window_length / 2 + 1])
112                .mapv(|x| (x.re * x.re + x.im * x.im).sqrt() as f64),
113        );
114    }
115    stft.permuted_axes((1, 0))
116}
117
118/// Convert Hz frequencies to fractional octaves (in-place).
119pub fn hz_to_octs_inplace(
120    frequencies: &mut Array1<f64>,
121    tuning: f64,
122    bins_per_octave: u32,
123) -> &mut Array1<f64> {
124    let a440 = 440.0 * 2_f64.powf(tuning / f64::from(bins_per_octave));
125    *frequencies /= a440 / 16.;
126    frequencies.mapv_inplace(f64::log2);
127    frequencies
128}
129
130/// FFT-based convolution (same-size output).
131pub fn convolve(input: &Array1<f64>, kernel: &Array1<f64>) -> Array1<f64> {
132    let mut common_length = input.len() + kernel.len();
133    if !common_length.is_multiple_of(2) {
134        common_length -= 1;
135    }
136    let mut padded_input = Array::from_elem(
137        common_length,
138        Complex {
139            re: 0.0_f64,
140            im: 0.0,
141        },
142    );
143    padded_input
144        .slice_mut(s![..input.len()])
145        .assign(&input.mapv(|x| Complex::new(x, 0.)));
146    let mut padded_kernel = Array::from_elem(
147        common_length,
148        Complex {
149            re: 0.0_f64,
150            im: 0.0,
151        },
152    );
153    padded_kernel
154        .slice_mut(s![..kernel.len()])
155        .assign(&kernel.mapv(|x| Complex::new(x, 0.)));
156
157    let mut planner = FftPlanner::new();
158    let forward = planner.plan_fft_forward(common_length);
159    forward.process(padded_input.as_slice_mut().unwrap());
160    forward.process(padded_kernel.as_slice_mut().unwrap());
161
162    let mut multiplication = padded_input * padded_kernel;
163
164    let back = planner.plan_fft_inverse(common_length);
165    back.process(multiplication.as_slice_mut().unwrap());
166
167    let multiplication_length = multiplication.len() as f64;
168    let multiplication = multiplication
169        .slice_move(s![
170            (kernel.len() - 1) / 2..(kernel.len() - 1) / 2 + input.len()
171        ])
172        .mapv(|x| x.re);
173    multiplication / multiplication_length
174}
175
176/// Standard deviation of a slice of f32 values.
177pub fn std_deviation(values: &[f32]) -> f32 {
178    if values.len() <= 1 {
179        return 0.0;
180    }
181    let m = mean(values);
182    let variance = values.iter().map(|&x| (x - m) * (x - m)).sum::<f32>() / values.len() as f32;
183    variance.sqrt()
184}
185
#[cfg(test)]
mod tests {
    use super::*;

    /// The mean of an evenly spaced ramp is its middle value.
    #[test]
    fn test_mean() {
        let ramp = [0.0, 1.0, 2.0, 3.0, 4.0];
        assert_eq!(mean(&ramp), 2.0);
    }

    /// A zero sample forces the result to zero; powers of two give an exact mean.
    #[test]
    fn test_geometric_mean() {
        let with_zero = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        assert_eq!(geometric_mean(&with_zero), 0.0);

        let powers_of_two = [4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0, 2.0];
        let error = (2.0 - geometric_mean(&powers_of_two)).abs();
        assert!(error < 0.0001);
    }

    #[test]
    #[should_panic(expected = "geometric_mean input length must be a multiple of 8")]
    fn test_geometric_mean_panics_on_non_multiple_of_8() {
        // Issue #3: chunks_exact(8) silently drops the remainder, producing
        // a wrong result when the length is not a multiple of 8.
        geometric_mean(&[1.0f32; 9]);
    }

    /// An alternating signal crosses on every step; a constant one never does.
    #[test]
    fn test_number_crossings() {
        let alternating = [-1.0, 1.0, -1.0, 1.0];
        assert_eq!(number_crossings(&alternating), 3);

        let constant = [1.0, 1.0, 1.0];
        assert_eq!(number_crossings(&constant), 0);
    }

    /// Midpoint and both ends of the source interval map to 0, -1 and 1.
    #[test]
    fn test_normalize() {
        let cases = [(0.5, 0.0), (0.0, -1.0), (1.0, 1.0)];
        for (value, expected) in cases {
            assert!((expected - normalize(value, 0.0, 1.0)).abs() < 1e-6);
        }
    }

    /// Textbook data set whose population standard deviation is exactly 2.
    #[test]
    fn test_std_deviation() {
        let samples = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        let error = (2.0 - std_deviation(&samples)).abs();
        assert!(error < 0.01);
    }

    /// Padding a ramp by 3 mirrors three samples around each end.
    #[test]
    fn test_reflect_pad() {
        let ramp: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let padded = reflect_pad(&ramp, 3);
        assert_eq!(&padded[..4], &[3.0, 2.0, 1.0, 0.0]);
        assert_eq!(&padded[3..103], &ramp[..]);
        assert_eq!(&padded[103..106], &[98.0, 97.0, 96.0]);
    }
}