math_audio_dsp/audio_features/
utils.rs1use ndarray::{Array, Array1, Array2, arr1, s};
7use rustfft::FftPlanner;
8use rustfft::num_complex::Complex;
9
10use crate::analysis::plan_fft_forward;
11use crate::stft::generate_hann_window;
12
/// Linearly maps `value` from the range `[min_value, max_value]` onto `[-1, 1]`.
///
/// Nothing is clamped: values outside the input range map outside `[-1, 1]`.
/// When `max_value == min_value` the division is by zero and the result is not finite.
pub fn normalize(value: f32, min_value: f32, max_value: f32) -> f32 {
    let offset = value - min_value;
    let span = max_value - min_value;
    2. * offset / span - 1.
}
17
/// Arithmetic mean of `input`, or `0.0` when `input` is empty.
pub fn mean(input: &[f32]) -> f32 {
    match input.len() {
        0 => 0.0,
        len => input.iter().sum::<f32>() / len as f32,
    }
}
25
/// Computes the geometric mean of `input`.
///
/// Samples are multiplied eight at a time in `f64`. The running product is kept as an
/// integer sum of binary exponents plus a mantissa in `[1, 2)`, so the result stays
/// finite no matter how long the input is.
///
/// Returns `0.0` as soon as any group of eight samples multiplies to zero, and `NaN` for
/// an empty slice. The algorithm assumes non-negative samples (e.g. magnitudes): a
/// negative product's sign bit ends up in the exponent sum, making the result
/// meaningless. Samples large enough that eight of them (times 2^500) overflow `f64` are
/// not supported either.
///
/// # Panics
///
/// Panics if `input.len()` is not a multiple of 8.
pub fn geometric_mean(input: &[f32]) -> f32 {
    assert!(
        input.len().is_multiple_of(8),
        "geometric_mean input length must be a multiple of 8, got {}",
        input.len()
    );
    const MANTISSA_MASK: u64 = 0x000F_FFFF_FFFF_FFFF;
    // Exponent field of 1.0: OR-ing it onto a mantissa yields a value in [1, 2).
    const UNIT_EXPONENT: u64 = 0x3FF0_0000_0000_0000;
    const EXPONENT_BIAS: i64 = 1023;
    // 2^500, applied once per chunk so products of small samples stay out of the
    // subnormal range (where the exponent/mantissa split below would be wrong).
    const SCALE: f64 = 3.273390607896142e150;
    const SCALE_LOG2: f64 = 500.0;

    let mut exponents: i64 = 0;
    let mut mantissas: f64 = 1.;
    for ch in input.chunks_exact(8) {
        let mut m = (ch[0] as f64 * ch[1] as f64) * (ch[2] as f64 * ch[3] as f64);
        m *= SCALE;
        m *= (ch[4] as f64 * ch[5] as f64) * (ch[6] as f64 * ch[7] as f64);
        if m == 0. {
            return 0.;
        }
        exponents += (m.to_bits() >> 52) as i64;
        mantissas *= f64::from_bits((m.to_bits() & MANTISSA_MASK) | UNIT_EXPONENT);
        // `mantissas` is now in [1, 4). Move its exponent into the integer sum so the
        // product cannot overflow; the split is exact, so precision is unchanged.
        exponents += (mantissas.to_bits() >> 52) as i64 - EXPONENT_BIAS;
        mantissas = f64::from_bits((mantissas.to_bits() & MANTISSA_MASK) | UNIT_EXPONENT);
    }

    let n = input.len() as f64;
    // Each chunk contributed a biased exponent plus the 2^500 scale; removing both per
    // chunk is (1023 + 500) / 8 per sample.
    let log2_mean = (mantissas.log2() + exponents as f64) / n
        - (EXPONENT_BIAS as f64 + SCALE_LOG2) / 8.;
    log2_mean.exp2() as f32
}
50
/// Counts how many times consecutive samples switch between positive (`> 0`) and
/// non-positive (`<= 0`, which also covers zero and NaN).
///
/// Empty and single-sample inputs have no crossings.
pub fn number_crossings(input: &[f32]) -> u32 {
    input
        .windows(2)
        .filter(|pair| (pair[0] > 0.) != (pair[1] > 0.))
        .fold(0u32, |count, _| count + 1)
}
68
/// Pads `array` on both sides by mirroring it around its first and last samples,
/// without repeating the edge samples themselves (NumPy's `"reflect"` mode).
///
/// For `[a, b, c, d]` and `pad = 2` the result is `[c, b, a, b, c, d, c, b]`.
///
/// # Panics
///
/// Panics if `pad >= array.len()`: a reflection of width `pad` needs at least `pad + 1`
/// samples.
pub fn reflect_pad(array: &[f32], pad: usize) -> Vec<f32> {
    assert!(
        pad < array.len(),
        "reflect_pad: pad ({pad}) must be smaller than the input length ({})",
        array.len()
    );
    let last = array.len() - 1;
    let mut output = Vec::with_capacity(array.len() + 2 * pad);
    // Mirror of the `pad` samples right after the first one.
    output.extend(array[1..=pad].iter().rev());
    output.extend_from_slice(array);
    // Mirror of the `pad` samples right before the last one.
    output.extend(array[last - pad..last].iter().rev());
    output
}
83
84pub fn stft(signal: &[f32], window_length: usize, hop_length: usize) -> Array2<f64> {
86 let mut stft = Array2::zeros((
87 (signal.len() as f32 / hop_length as f32).ceil() as usize,
88 window_length / 2 + 1,
89 ));
90 let signal = reflect_pad(signal, window_length / 2);
91
92 let hann_window = Array::from_vec(generate_hann_window(window_length));
94
95 let fft = plan_fft_forward(window_length);
96
97 for (window, mut stft_col) in signal
98 .windows(window_length)
99 .step_by(hop_length)
100 .zip(stft.rows_mut())
101 {
102 let mut fft_input = (arr1(window) * &hann_window).mapv(|x| Complex::new(x, 0.));
103 match fft_input.as_slice_mut() {
104 Some(s) => fft.process(s),
105 None => {
106 fft.process(&mut fft_input.to_vec());
107 }
108 };
109 stft_col.assign(
110 &fft_input
111 .slice(s![..window_length / 2 + 1])
112 .mapv(|x| (x.re * x.re + x.im * x.im).sqrt() as f64),
113 );
114 }
115 stft.permuted_axes((1, 0))
116}
117
118pub fn hz_to_octs_inplace(
120 frequencies: &mut Array1<f64>,
121 tuning: f64,
122 bins_per_octave: u32,
123) -> &mut Array1<f64> {
124 let a440 = 440.0 * 2_f64.powf(tuning / f64::from(bins_per_octave));
125 *frequencies /= a440 / 16.;
126 frequencies.mapv_inplace(f64::log2);
127 frequencies
128}
129
130pub fn convolve(input: &Array1<f64>, kernel: &Array1<f64>) -> Array1<f64> {
132 let mut common_length = input.len() + kernel.len();
133 if !common_length.is_multiple_of(2) {
134 common_length -= 1;
135 }
136 let mut padded_input = Array::from_elem(
137 common_length,
138 Complex {
139 re: 0.0_f64,
140 im: 0.0,
141 },
142 );
143 padded_input
144 .slice_mut(s![..input.len()])
145 .assign(&input.mapv(|x| Complex::new(x, 0.)));
146 let mut padded_kernel = Array::from_elem(
147 common_length,
148 Complex {
149 re: 0.0_f64,
150 im: 0.0,
151 },
152 );
153 padded_kernel
154 .slice_mut(s![..kernel.len()])
155 .assign(&kernel.mapv(|x| Complex::new(x, 0.)));
156
157 let mut planner = FftPlanner::new();
158 let forward = planner.plan_fft_forward(common_length);
159 forward.process(padded_input.as_slice_mut().unwrap());
160 forward.process(padded_kernel.as_slice_mut().unwrap());
161
162 let mut multiplication = padded_input * padded_kernel;
163
164 let back = planner.plan_fft_inverse(common_length);
165 back.process(multiplication.as_slice_mut().unwrap());
166
167 let multiplication_length = multiplication.len() as f64;
168 let multiplication = multiplication
169 .slice_move(s![
170 (kernel.len() - 1) / 2..(kernel.len() - 1) / 2 + input.len()
171 ])
172 .mapv(|x| x.re);
173 multiplication / multiplication_length
174}
175
/// Population standard deviation of `values`: the squared deviations from the mean are
/// divided by `values.len()`, not `values.len() - 1`.
///
/// Returns `0.0` for empty or single-element input.
pub fn std_deviation(values: &[f32]) -> f32 {
    let count = values.len();
    if count < 2 {
        return 0.0;
    }
    let average = values.iter().sum::<f32>() / count as f32;
    let squared_deviations: f32 = values
        .iter()
        .map(|&v| {
            let delta = v - average;
            delta * delta
        })
        .sum();
    (squared_deviations / count as f32).sqrt()
}
185
#[cfg(test)]
mod tests {
    use super::*;

    // Inputs are fixed-size arrays borrowed as slices; `vec!` would only add a heap
    // allocation (clippy `useless_vec`).

    #[test]
    fn test_mean() {
        assert_eq!(2.0, mean(&[0.0, 1.0, 2.0, 3.0, 4.0]));
        assert_eq!(0.0, mean(&[]));
    }

    #[test]
    fn test_geometric_mean() {
        let with_zero = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        assert_eq!(0.0, geometric_mean(&with_zero));

        let numbers = [4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0, 2.0];
        assert!(0.0001 > (2.0 - geometric_mean(&numbers)).abs());
    }

    #[test]
    #[should_panic(expected = "geometric_mean input length must be a multiple of 8")]
    fn test_geometric_mean_panics_on_non_multiple_of_8() {
        geometric_mean(&[1.0f32; 9]);
    }

    #[test]
    fn test_number_crossings() {
        assert_eq!(3, number_crossings(&[-1.0, 1.0, -1.0, 1.0]));
        assert_eq!(0, number_crossings(&[1.0, 1.0, 1.0]));
        assert_eq!(0, number_crossings(&[]));
        // Zero counts as non-positive: 0 -> 1 is a crossing, 0 -> -1 is not.
        assert_eq!(3, number_crossings(&[0.0, 1.0, 0.0, -1.0, 2.0]));
    }

    #[test]
    fn test_normalize() {
        assert!((0.0 - normalize(0.5, 0.0, 1.0)).abs() < 1e-6);
        assert!((-1.0 - normalize(0.0, 0.0, 1.0)).abs() < 1e-6);
        assert!((1.0 - normalize(1.0, 0.0, 1.0)).abs() < 1e-6);
    }

    #[test]
    fn test_std_deviation() {
        let values = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert!((2.0 - std_deviation(&values)).abs() < 0.01);
        assert_eq!(0.0, std_deviation(&[]));
        assert_eq!(0.0, std_deviation(&[3.0]));
    }

    #[test]
    fn test_reflect_pad() {
        let array: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let output = reflect_pad(&array, 3);
        assert_eq!(output.len(), 106);
        assert_eq!(&output[..4], &[3.0, 2.0, 1.0, 0.0]);
        assert_eq!(&output[3..103], &array[..]);
        assert_eq!(&output[103..106], &[98.0, 97.0, 96.0]);
    }
}