math_audio_dsp/audio_features/
utils.rs1use ndarray::{Array, Array1, Array2, arr1, s};
7use rustfft::FftPlanner;
8use rustfft::num_complex::Complex;
9
10use crate::analysis::plan_fft_forward;
11use crate::stft::generate_hann_window;
12
/// Linearly rescales `value` from the range `[min_value, max_value]` to `[-1, 1]`.
///
/// `min_value` maps to `-1.0` and `max_value` maps to `1.0`. Values outside
/// the range are not clamped. When `min_value == max_value` the division by
/// zero yields NaN or an infinity.
pub fn normalize(value: f32, min_value: f32, max_value: f32) -> f32 {
    let offset = value - min_value;
    let span = max_value - min_value;
    // Same operation order as `2 * offset / span - 1`.
    2. * offset / span - 1.
}
17
/// Returns the arithmetic mean of `input`, or `0.0` when `input` is empty.
pub fn mean(input: &[f32]) -> f32 {
    match input.len() {
        0 => 0.0,
        count => {
            let total = input.iter().sum::<f32>();
            total / count as f32
        }
    }
}
25
/// Returns the geometric mean of `input` (the `n`-th root of the product of
/// its `n` samples).
///
/// Samples are consumed eight at a time: the product of each group is split
/// into its IEEE-754 exponent and mantissa so that a single `log2` suffices
/// for all full groups. Samples left over after the last full group of eight
/// are folded in individually, so any input length is supported.
///
/// Returns `0.0` for an empty slice or as soon as a zero sample is found.
///
/// Samples are expected to be non-negative, finite and below roughly `2^65`
/// (larger values can overflow the eight-sample product); the result for
/// other inputs is unspecified.
pub fn geometric_mean(input: &[f32]) -> f32 {
    /// 2^500: keeps the eight-sample product away from underflow/denormals.
    const SCALE: f64 = 3.273390607896142e150;
    /// log2 of `SCALE`, removed again from every extracted exponent.
    const SCALE_LOG2: i64 = 500;
    /// Exponent bias of an IEEE-754 double.
    const EXPONENT_BIAS: i64 = 1023;
    /// Low 52 bits of a double: the mantissa field.
    const MANTISSA_MASK: u64 = 0x000F_FFFF_FFFF_FFFF;
    /// Bit pattern of `1.0_f64`; OR-ing a mantissa into it yields a value in `[1, 2)`.
    const ONE_BITS: u64 = 0x3FF0_0000_0000_0000;
    /// Threshold above which the running mantissa product is renormalised.
    const RENORMALIZE_ABOVE: f64 = 1e150;

    if input.is_empty() {
        return 0.0;
    }

    // log2(product) is tracked as `exponents + log2(mantissas)`.
    let mut exponents: i64 = 0;
    let mut mantissas: f64 = 1.0;

    let chunks = input.chunks_exact(8);
    let tail = chunks.remainder();
    for ch in chunks {
        let mut m = (ch[0] as f64 * ch[1] as f64) * (ch[2] as f64 * ch[3] as f64);
        m *= SCALE;
        m *= (ch[4] as f64 * ch[5] as f64) * (ch[6] as f64 * ch[7] as f64);
        if m == 0.0 {
            return 0.0;
        }
        let bits = m.to_bits();
        exponents += (bits >> 52) as i64 - (EXPONENT_BIAS + SCALE_LOG2);
        mantissas *= f64::from_bits((bits & MANTISSA_MASK) | ONE_BITS);

        // Each factor is in [1, 2), so the product only grows; move its
        // exponent into `exponents` before it can overflow on long inputs.
        if mantissas > RENORMALIZE_ABOVE {
            let bits = mantissas.to_bits();
            exponents += (bits >> 52) as i64 - EXPONENT_BIAS;
            mantissas = f64::from_bits((bits & MANTISSA_MASK) | ONE_BITS);
        }
    }

    let mut log2_sum = mantissas.log2() + exponents as f64;
    // Fewer than eight samples remain; handle them directly in the log domain.
    for &sample in tail {
        if sample == 0.0 {
            return 0.0;
        }
        log2_sum += f64::from(sample).log2();
    }

    (log2_sum / input.len() as f64).exp2() as f32
}
45
/// Counts how many times consecutive samples of `input` switch between
/// strictly positive and non-positive.
///
/// A sample counts as positive only when it is `> 0`, so zero is grouped with
/// the negative values. Slices with fewer than two samples have no crossings.
pub fn number_crossings(input: &[f32]) -> u32 {
    input
        .windows(2)
        .map(|pair| u32::from((pair[0] > 0.) != (pair[1] > 0.)))
        .sum()
}
63
/// Pads `array` on both sides with `pad` samples mirrored around its first
/// and last elements (the edge samples themselves are not repeated).
///
/// For example, padding `[0, 1, 2, 3]` with `pad = 2` gives
/// `[2, 1, 0, 1, 2, 3, 2, 1]`.
///
/// With `pad == 0` a plain copy of `array` is returned.
///
/// # Panics
///
/// Panics if `pad` is non-zero and not smaller than `array.len()`, since
/// there are then not enough samples to mirror.
pub fn reflect_pad(array: &[f32], pad: usize) -> Vec<f32> {
    if pad == 0 {
        return array.to_vec();
    }
    assert!(
        pad < array.len(),
        "reflect_pad: pad ({pad}) must be smaller than the input length ({})",
        array.len()
    );

    // `pad < len` guarantees `len >= 2`, so neither subtraction below can underflow.
    let last = array.len() - 1;
    let mut output = Vec::with_capacity(array.len() + 2 * pad);
    // Mirror of the samples right after the first element.
    output.extend(array[1..=pad].iter().rev().copied());
    output.extend_from_slice(array);
    // Mirror of the samples right before the last element.
    output.extend(array[last - pad..last].iter().rev().copied());
    output
}
78
79pub fn stft(signal: &[f32], window_length: usize, hop_length: usize) -> Array2<f64> {
81 let mut stft = Array2::zeros((
82 (signal.len() as f32 / hop_length as f32).ceil() as usize,
83 window_length / 2 + 1,
84 ));
85 let signal = reflect_pad(signal, window_length / 2);
86
87 let hann_window = Array::from_vec(generate_hann_window(window_length));
89
90 let fft = plan_fft_forward(window_length);
91
92 for (window, mut stft_col) in signal
93 .windows(window_length)
94 .step_by(hop_length)
95 .zip(stft.rows_mut())
96 {
97 let mut fft_input = (arr1(window) * &hann_window).mapv(|x| Complex::new(x, 0.));
98 match fft_input.as_slice_mut() {
99 Some(s) => fft.process(s),
100 None => {
101 fft.process(&mut fft_input.to_vec());
102 }
103 };
104 stft_col.assign(
105 &fft_input
106 .slice(s![..window_length / 2 + 1])
107 .mapv(|x| (x.re * x.re + x.im * x.im).sqrt() as f64),
108 );
109 }
110 stft.permuted_axes((1, 0))
111}
112
113pub fn hz_to_octs_inplace(
115 frequencies: &mut Array1<f64>,
116 tuning: f64,
117 bins_per_octave: u32,
118) -> &mut Array1<f64> {
119 let a440 = 440.0 * 2_f64.powf(tuning / f64::from(bins_per_octave));
120 *frequencies /= a440 / 16.;
121 frequencies.mapv_inplace(f64::log2);
122 frequencies
123}
124
125pub fn convolve(input: &Array1<f64>, kernel: &Array1<f64>) -> Array1<f64> {
127 let mut common_length = input.len() + kernel.len();
128 if !common_length.is_multiple_of(2) {
129 common_length -= 1;
130 }
131 let mut padded_input = Array::from_elem(
132 common_length,
133 Complex {
134 re: 0.0_f64,
135 im: 0.0,
136 },
137 );
138 padded_input
139 .slice_mut(s![..input.len()])
140 .assign(&input.mapv(|x| Complex::new(x, 0.)));
141 let mut padded_kernel = Array::from_elem(
142 common_length,
143 Complex {
144 re: 0.0_f64,
145 im: 0.0,
146 },
147 );
148 padded_kernel
149 .slice_mut(s![..kernel.len()])
150 .assign(&kernel.mapv(|x| Complex::new(x, 0.)));
151
152 let mut planner = FftPlanner::new();
153 let forward = planner.plan_fft_forward(common_length);
154 forward.process(padded_input.as_slice_mut().unwrap());
155 forward.process(padded_kernel.as_slice_mut().unwrap());
156
157 let mut multiplication = padded_input * padded_kernel;
158
159 let back = planner.plan_fft_inverse(common_length);
160 back.process(multiplication.as_slice_mut().unwrap());
161
162 let multiplication_length = multiplication.len() as f64;
163 let multiplication = multiplication
164 .slice_move(s![
165 (kernel.len() - 1) / 2..(kernel.len() - 1) / 2 + input.len()
166 ])
167 .mapv(|x| x.re);
168 multiplication / multiplication_length
169}
170
171pub fn std_deviation(values: &[f32]) -> f32 {
173 if values.len() <= 1 {
174 return 0.0;
175 }
176 let m = mean(values);
177 let variance = values.iter().map(|&x| (x - m) * (x - m)).sum::<f32>() / values.len() as f32;
178 variance.sqrt()
179}
180
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_mean() {
        let samples: [f32; 5] = [0.0, 1.0, 2.0, 3.0, 4.0];
        assert_eq!(mean(&samples), 2.0);
    }

    #[test]
    fn test_geometric_mean() {
        // A single zero sample forces the geometric mean to zero.
        let with_zero: [f32; 8] = [0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        assert_eq!(geometric_mean(&with_zero), 0.0);

        // The product is 2^8, so the eighth root is 2.
        let powers_of_two: [f32; 8] = [4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0, 2.0];
        assert!((2.0 - geometric_mean(&powers_of_two)).abs() < 0.0001);
    }

    #[test]
    fn test_number_crossings() {
        let alternating: [f32; 4] = [-1.0, 1.0, -1.0, 1.0];
        assert_eq!(number_crossings(&alternating), 3);

        let constant: [f32; 3] = [1.0, 1.0, 1.0];
        assert_eq!(number_crossings(&constant), 0);
    }

    #[test]
    fn test_normalize() {
        // (value, expected) pairs for the range [0, 1].
        let cases: [(f32, f32); 3] = [(0.5, 0.0), (0.0, -1.0), (1.0, 1.0)];
        for (value, expected) in cases {
            assert!((expected - normalize(value, 0.0, 1.0)).abs() < 1e-6);
        }
    }

    #[test]
    fn test_std_deviation() {
        let samples: [f32; 8] = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0];
        assert!((2.0 - std_deviation(&samples)).abs() < 0.01);
    }

    #[test]
    fn test_reflect_pad() {
        let ramp: Vec<f32> = (0..100).map(|x| x as f32).collect();
        let padded = reflect_pad(&ramp, 3);
        // Mirrored prefix followed by the first original sample.
        assert_eq!(&padded[..4], &[3.0, 2.0, 1.0, 0.0]);
        // The original signal is preserved in the middle.
        assert_eq!(&padded[3..103], &ramp[..]);
        // Mirrored suffix.
        assert_eq!(&padded[103..106], &[98.0, 97.0, 96.0]);
    }
}