1use likely_stable::{if_likely, unlikely};
2use log::warn;
3use ndarray::{Array, Array1, Array2, arr1, s};
4use rustfft::FftPlanner;
5use rustfft::num_complex::Complex;
6use std::f32::consts::PI;
7
8use crate::Feature;
9
/// Pads `array` on both ends with `pad` mirrored samples. The mirror is taken around the
/// first and last elements without repeating them (the `reflect` mode of `numpy.pad`).
///
/// For example, padding `[0., 1., 2., 3.]` by `2` yields `[2., 1., 0., 1., 2., 3., 2., 1.]`.
/// A `pad` of `0` returns a copy of `array`, even when `array` is empty.
///
/// # Panics
///
/// Panics if `pad` is non-zero and not smaller than `array.len()`, because there are then
/// not enough samples to mirror.
#[must_use]
#[inline]
pub fn reflect_pad(array: &[f32], pad: usize) -> Vec<f32> {
    if pad == 0 {
        return array.to_vec();
    }
    debug_assert!(pad < array.len(), "Padding is too large");
    // Mirror of `array[1..=pad]`; `array[0]` itself is not repeated.
    let prefix = array[1..=pad].iter().rev().copied();
    // Mirror of the `pad` samples preceding the last one. The slice above succeeded, so
    // `array.len() > pad >= 1` and `last - pad` cannot underflow.
    let last = array.len() - 1;
    let suffix = array[last - pad..last].iter().rev().copied();

    let mut output = Vec::with_capacity(array.len() + 2 * pad);
    output.extend(prefix);
    output.extend_from_slice(array);
    output.extend(suffix);
    output
}
26
27#[must_use]
29#[allow(clippy::missing_inline_in_public_items)]
30pub fn stft(signal: &[f32], window_length: usize, hop_length: usize) -> Array2<f32> {
31 debug_assert!(
32 window_length.is_multiple_of(2),
33 "Window length must be even"
34 );
35 debug_assert!(window_length < signal.len(), "Signal is too short");
36 debug_assert!(hop_length < window_length, "Hop length is too large");
37 let half_window_length = window_length / 2;
38 let mut stft = Array2::zeros((signal.len().div_ceil(hop_length), half_window_length + 1));
41 let signal = reflect_pad(signal, half_window_length);
42
43 #[allow(clippy::cast_precision_loss)]
45 let mut hann_window = -0.5
46 * Array::from_shape_fn(window_length + 1, |n| {
47 2. * n as f32 * PI / (window_length as f32)
48 })
49 .cos()
50 + 0.5;
51 hann_window = hann_window.slice_move(s![0..window_length]);
52 let mut planner = FftPlanner::new();
53 let fft = planner.plan_fft_forward(window_length);
54
55 #[allow(unused, reason = "it's not unused, but macro confuses poor clippy")]
56 for (window, mut stft_col) in signal
57 .windows(window_length)
58 .step_by(hop_length)
59 .zip(stft.rows_mut())
60 {
61 let mut signal = (arr1(window) * &hann_window).mapv(|x| Complex::new(x, 0.));
62
63 if_likely! {let Some(s) = signal.as_slice_mut() => {
64 fft.process(s);
65 } else {
66 warn!("non-contiguous slice found for stft; expect slow performances.");
67 fft.process(&mut signal.to_vec());
68 }}
69
70 stft_col.assign(
71 &signal
72 .slice(s![..=half_window_length])
73 .mapv(|x| x.re.hypot(x.im)),
74 );
75 }
76 stft.permuted_axes((1, 0))
77}
78
/// Arithmetic mean of `input`, accumulated in `f32`.
///
/// An empty slice has no mean; `0.` is returned for it instead of dividing by zero.
#[allow(clippy::cast_precision_loss)]
pub(crate) fn mean(input: &[f32]) -> f32 {
    match input.len() {
        0 => 0.,
        count => {
            let total: f32 = input.iter().sum();
            total / count as f32
        }
    }
}
86
87pub(crate) trait Normalize {
88 const MAX_VALUE: Feature;
89 const MIN_VALUE: Feature;
90
91 fn normalize(&self, value: Feature) -> Feature {
92 2. * (value - Self::MIN_VALUE) / (Self::MAX_VALUE - Self::MIN_VALUE) - 1.
93 }
94}
95
/// Counts the adjacent sample pairs in `input` whose signs differ.
///
/// A sample counts as positive only when it is strictly greater than zero. Zero (and NaN)
/// is grouped with the negative samples, so `[0., 1., 0.]` has two crossings.
pub(crate) fn number_crossings(input: &[f32]) -> usize {
    let is_positive = |sample: &f32| *sample > 0.;

    let Some(successors) = input.get(1..) else {
        return 0;
    };

    input
        .iter()
        .zip(successors)
        .filter(|&(current, next)| is_positive(current) != is_positive(next))
        .count()
}
108
109#[must_use]
115#[allow(clippy::missing_inline_in_public_items)]
116pub fn geometric_mean(input: &[f32]) -> f32 {
117 debug_assert_eq!(input.len() % 8, 0, "Input size must be a multiple of 8");
118 if unlikely(input.is_empty()) {
119 return 0.;
120 }
121
122 let mut exponents: i32 = 0;
123 let mut mantissas: f64 = 1.;
124 for ch in input.chunks_exact(8) {
125 let mut m = (f64::from(ch[0]) * f64::from(ch[1])) * (f64::from(ch[2]) * f64::from(ch[3]));
126 m *= 3.273_390_607_896_142e150; m *= (f64::from(ch[4]) * f64::from(ch[5])) * (f64::from(ch[6]) * f64::from(ch[7]));
128 if unlikely(m == 0.) {
129 return 0.;
130 }
131 exponents += (m.to_bits() >> 52) as i32;
132 mantissas *= f64::from_bits((m.to_bits() & 0x000F_FFFF_FFFF_FFFF) | 0x3FF0_0000_0000_0000);
133 }
134
135 #[allow(clippy::cast_possible_truncation)]
136 let n = input.len() as u32;
137 #[allow(clippy::cast_possible_truncation)]
138 let result = (((mantissas.log2() + f64::from(exponents)) / f64::from(n) - (1023. + 500.) / 8.)
139 .exp2()) as f32;
140 result
141}
142
143pub(crate) fn hz_to_octs_inplace(
144 frequencies: &mut Array1<f32>,
145 tuning: f32,
146 bins_per_octave: u32,
147) -> &mut Array1<f32> {
148 #[allow(clippy::cast_precision_loss)]
149 let a440 = 440.0 * (tuning / bins_per_octave as f32).exp2();
150
151 *frequencies /= a440 / 16.;
152 frequencies.mapv_inplace(f32::log2);
153 frequencies
154}
155
#[cfg(test)]
mod tests {
    use super::*;
    use crate::decoder::{Decoder as DecoderTrait, MecompDecoder as Decoder};
    use ndarray::{Array, Array2, arr1};
    use ndarray_npy::ReadNpyExt;
    use std::{fs::File, path::Path};

    #[test]
    fn test_mean() {
        let numbers = vec![0.0, 1.0, 2.0, 3.0, 4.0];
        let mean = mean(&numbers);
        assert!(f32::EPSILON > (2.0 - mean).abs(), "{mean} !~= 2.0");
    }

    #[test]
    #[allow(clippy::too_many_lines)]
    fn test_geometric_mean() {
        // Any zero element makes the geometric mean zero.
        let numbers = vec![0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0];
        let mean = geometric_mean(&numbers);
        assert!(f32::EPSILON > (0.0 - mean).abs(), "{mean} !~= 0.0");

        let numbers = vec![4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0, 2.0];
        let mean = geometric_mean(&numbers);
        assert!(0.0001 > (2.0 - mean).abs(), "{mean} !~= 2.0");

        let numbers = vec![256., 4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0];
        let mean = geometric_mean(&numbers);
        assert!(
            0.0001 > (3.668_016_2 - mean).abs(),
            "{mean} !~= {}",
            3.668_016_172_818_685
        );

        // A subnormal element. The tolerance must stay well below the expected value
        // itself (~1.8e-5), otherwise a result of 0 would pass.
        let subnormal = vec![4.0, 2.0, 1.0, 4.0, 2.0, 1.0, 2.0, 1.0e-40_f32];
        let mean = geometric_mean(&subnormal);
        assert!(
            1e-9 > (1.834_008e-5 - mean).abs(),
            "{} !~= {}",
            mean,
            1.834_008_086_409_341_7e-5
        );

        // Compare relatively, and take the absolute value of the difference, so that
        // overshooting the expected value is caught as well.
        let maximum = vec![2_f32.powi(65); 256];
        let mean = geometric_mean(&maximum);
        assert!(
            1e-6 > ((2_f32.powi(65) - mean) / 2_f32.powi(65)).abs(),
            "{} !~= {}",
            mean,
            2_f32.powi(65)
        );

        let input = [
            0.024_454_033,
            0.088_096_89,
            0.445_543_62,
            0.827_535_03,
            0.158_220_93,
            1.444_224_5,
            3.697_138_5,
            3.678_955_6,
            1.598_157_2,
            1.017_271_8,
            1.443_609_6,
            3.145_710_2,
            2.764_110_8,
            0.839_523_5,
            0.248_968_29,
            0.070_631_73,
            0.355_419_4,
            0.352_001_4,
            0.797_365_1,
            0.661_970_8,
            0.784_104,
            0.876_795_7,
            0.287_382_66,
            0.048_841_28,
            0.322_706_5,
            0.334_907_47,
            0.185_888_75,
            0.135_449_42,
            0.140_177_46,
            0.111_815_82,
            0.152_631_61,
            0.221_993_12,
            0.056_798_387,
            0.083_892_57,
            0.070_009_65,
            0.202_903_29,
            0.370_717_38,
            0.231_543_18,
            0.023_348_59,
            0.013_220_183,
            0.035_887_096,
            0.029_505_49,
            0.090_338_57,
            0.176_795_04,
            0.081_421_87,
            0.003_326_808_6,
            0.012_269_007,
            0.016_257_336,
            0.027_027_424,
            0.017_253_408,
            0.017_230_038,
            0.021_678_915,
            0.018_645_158,
            0.005_417_136,
            0.006_650_174_5,
            0.020_159_671,
            0.026_623_515,
            0.005_166_793_7,
            0.016_880_387,
            0.009_935_223_5,
            0.011_079_361,
            0.013_200_151,
            0.005_320_572_3,
            0.005_070_289_6,
            0.008_130_498,
            0.009_006_041,
            0.003_602_499_8,
            0.006_440_387_6,
            0.004_656_151,
            0.002_513_185_8,
            0.003_084_559_7,
            0.008_722_531,
            0.017_871_628,
            0.022_656_294,
            0.017_539_924,
            0.009_439_588_5,
            0.003_085_72,
            0.001_358_616_6,
            0.002_746_787_2,
            0.005_413_010_3,
            0.004_140_312,
            0.000_143_587_14,
            0.001_371_840_8,
            0.004_472_961,
            0.003_769_122,
            0.003_259_129_6,
            0.003_637_24,
            0.002_445_332_2,
            0.000_590_368_93,
            0.000_647_898_65,
            0.001_745_297,
            0.000_867_165_5,
            0.002_156_236_2,
            0.001_075_606_8,
            0.002_009_199_5,
            0.001_537_388_5,
            0.000_984_620_4,
            0.000_292_002_49,
            0.000_921_162_4,
            0.000_535_111_8,
            0.001_491_276_5,
            0.002_065_137_5,
            0.000_661_122_26,
            0.000_850_054_26,
            0.001_900_590_1,
            0.000_639_584_5,
            0.002_262_803,
            0.003_094_018_2,
            0.002_089_161_7,
            0.001_215_059,
            0.001_311_408_4,
            0.000_470_959,
            0.000_665_480_7,
            0.001_430_32,
            0.001_791_889_3,
            0.000_863_200_75,
            0.000_560_445_5,
            0.000_828_417_54,
            0.000_669_453_9,
            0.000_822_765,
            0.000_616_575_8,
            0.001_189_319,
            0.000_730_024_5,
            0.000_623_748_1,
            0.001_207_644_4,
            0.001_474_674_2,
            0.002_033_916,
            0.001_500_169_9,
            0.000_520_51,
            0.000_445_643_32,
            0.000_558_462_75,
            0.000_897_786_64,
            0.000_805_247_05,
            0.000_726_536_44,
            0.000_673_052_6,
            0.000_994_064_5,
            0.001_109_393_7,
            0.001_295_099_7,
            0.000_982_682_2,
            0.000_876_651_8,
            0.001_654_928_7,
            0.000_929_064_35,
            0.000_291_306_23,
            0.000_250_490_47,
            0.000_228_488_02,
            0.000_269_673_15,
            0.000_237_375_09,
            0.000_969_406_1,
            0.001_063_811_8,
            0.000_793_428_86,
            0.000_590_835_06,
            0.000_476_389_9,
            0.000_951_664_1,
            0.000_692_231_46,
            0.000_557_113_7,
            0.000_851_769_7,
            0.001_071_027_7,
            0.000_610_243_9,
            0.000_746_876_23,
            0.000_849_898_44,
            0.000_495_806_2,
            0.000_526_994,
            0.000_215_249_22,
            0.000_096_684_314,
            0.000_654_554_4,
            0.001_220_697_3,
            0.001_210_358_3,
            0.000_920_454_33,
            0.000_924_843_5,
            0.000_812_128_4,
            0.000_239_532_56,
            0.000_931_822_4,
            0.001_043_966_3,
            0.000_483_734_15,
            0.000_298_952_22,
            0.000_484_425_4,
            0.000_666_829_5,
            0.000_998_398_5,
            0.000_860_489_7,
            0.000_183_153_23,
            0.000_309_180_8,
            0.000_542_646_2,
            0.001_040_391_5,
            0.000_755_456_6,
            0.000_284_601_7,
            0.000_600_979_3,
            0.000_765_056_9,
            0.000_562_810_46,
            0.000_346_616_55,
            0.000_236_224_32,
            0.000_598_710_6,
            0.000_295_684_27,
            0.000_386_978_06,
            0.000_584_258,
            0.000_567_097_6,
            0.000_613_644_4,
            0.000_564_549_3,
            0.000_235_384_52,
            0.000_285_574_6,
            0.000_385_352_93,
            0.000_431_935_65,
            0.000_731_246_5,
            0.000_603_072_8,
            0.001_033_130_8,
            0.001_195_216_2,
            0.000_824_500_7,
            0.000_422_183_63,
            0.000_821_760_16,
            0.001_132_246,
            0.000_891_406_73,
            0.000_635_158_8,
            0.000_372_681_56,
            0.000_230_35,
            0.000_628_649_3,
            0.000_806_159_9,
            0.000_661_622_15,
            0.000_227_139_01,
            0.000_214_694_96,
            0.000_665_457_7,
            0.000_513_901,
            0.000_391_766_78,
            0.001_079_094_7,
            0.000_735_363_7,
            0.000_171_665_73,
            0.000_439_648_87,
            0.000_295_145_3,
            0.000_177_047_08,
            0.000_182_958_97,
            0.000_926_536_04,
            0.000_832_408_3,
            0.000_804_168_4,
            0.001_131_809_3,
            0.001_187_149_6,
            0.000_806_948_8,
            0.000_628_624_75,
            0.000_591_386_1,
            0.000_472_182_3,
            0.000_163_652_31,
            0.000_177_876_57,
            0.000_425_363_75,
            0.000_573_699_3,
            0.000_434_679_24,
            0.000_090_282_94,
            0.000_172_573_55,
            0.000_501_957_4,
            0.000_614_716_8,
            0.000_216_780_5,
            0.000_148_974_3,
            0.000_055_081_473,
            0.000_296_264_13,
            0.000_378_055_67,
            0.000_147_361_96,
            0.000_262_513_64,
            0.000_162_118_42,
            0.000_185_347_7,
            0.000_138_735_4,
        ];
        assert!(
            0.000_000_01 > (0.002_575_059_7 - geometric_mean(&input)).abs(),
            "{} !~= 0.0025750597",
            geometric_mean(&input)
        );
    }

    #[test]
    fn test_hz_to_octs_inplace() {
        let mut frequencies = arr1(&[32., 64., 128., 256.]);
        let expected = arr1(&[0.168_640_29, 1.168_640_29, 2.168_640_29, 3.168_640_29]);

        hz_to_octs_inplace(&mut frequencies, 0.5, 10)
            .iter()
            .zip(expected.iter())
            .for_each(|(x, y)| assert!(0.0001 > (x - y).abs(), "{x} !~= {y}"));
    }

    #[test]
    fn test_compute_stft() {
        let file = File::open("data/librosa-stft.npy").unwrap();
        let expected_stft = Array2::<f32>::read_npy(file).unwrap();

        let song = Decoder::new()
            .unwrap()
            .decode(Path::new("data/piano.flac"))
            .unwrap();

        let stft = stft(&song.samples, 2048, 512);

        assert!(!stft.is_empty() && !expected_stft.is_empty(), "Empty STFT");
        // The element-wise comparison below walks both arrays in logical order, so it is
        // only meaningful when the shapes agree.
        assert_eq!(stft.shape(), expected_stft.shape(), "STFT shape mismatch");
        for (expected, actual) in expected_stft.iter().zip(stft.iter()) {
            assert!(
                0.0001 > (expected - actual).abs(),
                "{expected} !~= {actual}"
            );
        }
    }

    #[test]
    fn test_reflect_pad() {
        let array = Array::range(0., 100_000., 1.);

        let output = reflect_pad(array.as_slice().unwrap(), 3);
        assert_eq!(&output[..4], &[3.0, 2.0, 1.0, 0.]);
        assert_eq!(&output[3..100_003], array.to_vec());
        assert_eq!(&output[100_003..100_006], &[99998.0, 99997.0, 99996.0]);
    }
}