use super::utils::{geometric_mean, mean, normalize, std_deviation};
use rustfft::FftPlanner;
use rustfft::num_complex::Complex;
use std::f32::consts::PI;
/// Number of samples in each analysis frame; also the length of the FFT planned for it.
const WINDOW_SIZE: usize = 512;
/// Number of samples between the starts of consecutive frames (a quarter of `WINDOW_SIZE`).
const HOP_SIZE: usize = WINDOW_SIZE / 4;
/// Computes six summary statistics describing the short-time spectrum of `samples`.
///
/// The signal is cut into overlapping frames of `WINDOW_SIZE` samples whose starts are
/// `HOP_SIZE` samples apart. Each frame is multiplied by a Hann window and run through a
/// forward FFT; the magnitudes of the first `WINDOW_SIZE / 2 + 1` bins are then used to
/// derive three per-frame descriptors:
///
/// * **centroid** – magnitude-weighted mean bin index, converted to a frequency
///   (`0.0` when the frame has no magnitude at all);
/// * **rolloff** – frequency of the first bin at which the cumulative squared magnitude
///   reaches 95 % of the frame's total;
/// * **flatness** – geometric mean divided by arithmetic mean of the magnitudes
///   (`0.0` when the geometric mean is zero).
///
/// # Parameters
/// * `samples` – mono signal; trailing samples that do not fill a whole frame are ignored.
/// * `sample_rate` – sampling rate of `samples`; frequencies are expressed in its unit.
///
/// # Returns
/// `[centroid_mean, centroid_std, rolloff_mean, rolloff_std, flatness_mean, flatness_std]`,
/// each passed through `normalize` with the range `[0, sample_rate / 2]` for the
/// frequency features and `[0, 1]` for flatness.
///
/// NOTE(review): when `samples` is shorter than `WINDOW_SIZE` no frame is produced and the
/// statistics are taken over empty slices; the result then depends on how `mean` and
/// `std_deviation` treat empty input — confirm against `utils`.
pub fn compute_spectral_features(samples: &[f32], sample_rate: u32) -> Vec<f32> {
    let sr = sample_rate as f32;
    let half_sr = sr / 2.0;
    let n_bins = WINDOW_SIZE / 2 + 1;

    let hann: Vec<f32> = (0..WINDOW_SIZE)
        .map(|n| 0.5 - 0.5 * f32::cos(2.0 * PI * n as f32 / WINDOW_SIZE as f32))
        .collect();

    let mut planner = FftPlanner::<f32>::new();
    let fft = planner.plan_fft_forward(WINDOW_SIZE);

    let mut values_centroid = Vec::new();
    let mut values_rolloff = Vec::new();
    let mut values_flatness = Vec::new();

    // Scratch storage reused by every frame instead of being reallocated per iteration.
    let mut buffer: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); WINDOW_SIZE];
    let mut norms: Vec<f32> = Vec::with_capacity(n_bins);

    for frame in samples.windows(WINDOW_SIZE).step_by(HOP_SIZE) {
        // Every slot is overwritten, so nothing from the previous frame's FFT output leaks in.
        for ((slot, &s), &w) in buffer.iter_mut().zip(frame.iter()).zip(hann.iter()) {
            *slot = Complex::new(s * w, 0.0);
        }
        fft.process(&mut buffer);

        norms.clear();
        norms.extend(buffer[..n_bins].iter().map(|c| c.norm()));

        // --- Centroid ---------------------------------------------------------------
        let sum_mag: f32 = norms.iter().sum();
        let centroid_bin = if sum_mag > 0.0 {
            norms
                .iter()
                .enumerate()
                .map(|(i, &m)| i as f32 * m)
                .sum::<f32>()
                / sum_mag
        } else {
            0.0
        };
        values_centroid.push(centroid_bin * sr / WINDOW_SIZE as f32);

        // --- Rolloff ----------------------------------------------------------------
        let total_energy: f32 = norms.iter().map(|&m| m * m).sum();
        let threshold = 0.95 * total_energy;
        let mut cumulative = 0.0_f32;
        // Stays at bin 0 if the threshold is never reached.
        let mut rolloff_bin = 0.0_f32;
        for (i, &m) in norms.iter().enumerate() {
            cumulative += m * m;
            if cumulative >= threshold {
                rolloff_bin = i as f32;
                break;
            }
        }
        // No clamp needed: `norms` holds `n_bins` entries, so the index is at most
        // `WINDOW_SIZE / 2`.
        values_rolloff.push(rolloff_bin * sr / WINDOW_SIZE as f32);

        // --- Flatness ---------------------------------------------------------------
        // Both means must be taken over the *same* bins; otherwise the ratio is not
        // bounded by 1 and can leave the range handed to `normalize` below. The last
        // bin is left out, exactly as in the geometric-mean input used previously.
        let flat_bins = &norms[..WINDOW_SIZE / 2];
        let geo = geometric_mean(flat_bins);
        let flatness = if geo == 0.0 {
            0.0
        } else {
            geo / mean(flat_bins)
        };
        values_flatness.push(flatness);
    }

    let centroid_mean = normalize(mean(&values_centroid), 0.0, half_sr);
    let centroid_std = normalize(std_deviation(&values_centroid), 0.0, half_sr);
    let rolloff_mean = normalize(mean(&values_rolloff), 0.0, half_sr);
    let rolloff_std = normalize(std_deviation(&values_rolloff), 0.0, half_sr);
    let flatness_mean = normalize(mean(&values_flatness), 0.0, 1.0);
    let flatness_std = normalize(std_deviation(&values_flatness), 0.0, 1.0);

    vec![
        centroid_mean,
        centroid_std,
        rolloff_mean,
        rolloff_std,
        flatness_mean,
        flatness_std,
    ]
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate shared by the tests below.
    const TEST_RATE: u32 = 22050;

    /// Feeds 1024 zero samples and asserts that six features come back, each at most -0.99.
    #[test]
    fn test_spectral_silence() {
        let zeros = vec![0.0_f32; 1024];
        let out = compute_spectral_features(&zeros, TEST_RATE);
        assert_eq!(out.len(), 6);
        out.iter().for_each(|&f| {
            assert!(f <= -0.99, "expected ~-1 for silence, got {f}");
        });
    }

    /// Feeds 22050 samples of a 440-cycle sine and asserts that six features come back,
    /// each inside `[-1, 1]`.
    #[test]
    fn test_spectral_features_length() {
        let mut tone: Vec<f32> = Vec::with_capacity(TEST_RATE as usize);
        for i in 0..TEST_RATE {
            tone.push((2.0 * PI * 440.0 * i as f32 / 22050.0).sin());
        }
        let out = compute_spectral_features(&tone, TEST_RATE);
        assert_eq!(out.len(), 6);
        out.iter().for_each(|&f| {
            assert!((-1.0..=1.0).contains(&f), "feature out of range: {f}");
        });
    }
}