use super::utils::normalize;
use rustfft::FftPlanner;
use rustfft::num_complex::Complex;
use std::f32::consts::PI;
/// Number of samples in each analysis frame; also the length of the FFT.
const WINDOW_SIZE: usize = 512;

/// Number of samples between the starts of consecutive frames (half a window).
const HOP_SIZE: usize = WINDOW_SIZE / 2;

/// Fastest tempo, in beats per minute, considered by the lag search; also the
/// upper bound handed to `normalize`.
const MAX_BPM: f32 = 206.0;

/// Lower bound, in beats per minute, handed to `normalize`.
const MIN_BPM: f32 = 0.0;
/// Estimates the tempo of `samples` and returns it after passing it through
/// `normalize` with the bounds `MIN_BPM` and `MAX_BPM`.
///
/// `sample_rate` is the sampling rate of `samples`; it is forwarded to the
/// onset-envelope and lag-search helpers.
///
/// Returns the sentinel `-1.0` when no tempo can be reported:
/// * the onset envelope has fewer than four frames,
/// * the largest envelope value is below `1e-6` (effectively no onsets), or
/// * the autocorrelation search yields a non-positive BPM.
///
/// NOTE(review): the output range of `normalize` is defined in `super::utils`
/// and is not visible here — confirm it cannot collide with the `-1.0` sentinel.
pub fn compute_tempo(samples: &[f32], sample_rate: u32) -> f32 {
    const NO_TEMPO: f32 = -1.0;
    const MIN_FRAMES: usize = 4;
    const SILENCE_FLOOR: f32 = 1e-6;

    let envelope = compute_onset_envelope(samples, sample_rate);

    // Peak of the envelope, starting from 0.0 so an empty envelope reads as silent.
    let strongest = envelope.iter().fold(0.0f32, |peak, &flux| peak.max(flux));
    if envelope.len() < MIN_FRAMES || strongest < SILENCE_FLOOR {
        return NO_TEMPO;
    }

    match estimate_bpm_autocorrelation(&envelope, sample_rate) {
        bpm if bpm <= 0.0 => NO_TEMPO,
        bpm => normalize(bpm, MIN_BPM, MAX_BPM),
    }
}
fn compute_onset_envelope(samples: &[f32], _sample_rate: u32) -> Vec<f32> {
let n_bins = WINDOW_SIZE / 2 + 1;
let hann: Vec<f32> = (0..WINDOW_SIZE)
.map(|n| 0.5 - 0.5 * f32::cos(2.0 * PI * n as f32 / WINDOW_SIZE as f32))
.collect();
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(WINDOW_SIZE);
let mut prev_spectrum = vec![0.0f32; n_bins];
let mut onset_envelope = Vec::new();
for chunk in samples.windows(WINDOW_SIZE).step_by(HOP_SIZE) {
let mut buffer: Vec<Complex<f32>> = chunk
.iter()
.zip(hann.iter())
.map(|(&s, &w)| Complex::new(s * w, 0.0))
.collect();
fft.process(&mut buffer);
let spectrum: Vec<f32> = buffer[..n_bins].iter().map(|c| c.norm()).collect();
let flux: f32 = spectrum
.iter()
.zip(prev_spectrum.iter())
.map(|(&curr, &prev)| (curr - prev).max(0.0))
.sum();
onset_envelope.push(flux);
prev_spectrum = spectrum;
}
onset_envelope
}
/// Estimates a tempo in beats per minute from an onset envelope by searching
/// for the lag with the strongest prior-weighted autocorrelation.
///
/// `onset_envelope` is expected to hold one value per analysis frame, with
/// frames `HOP_SIZE` samples apart at `sample_rate` Hz, so a lag of `k` frames
/// corresponds to `60 * sample_rate / (HOP_SIZE * k)` BPM. Lags are searched
/// from the one matching `MAX_BPM` up to the one matching 30 BPM, capped at
/// half the envelope length.
///
/// Each candidate's mean-removed autocorrelation is weighted by a Gaussian
/// prior centred on 120 BPM (standard deviation 40 BPM). Only lags whose
/// autocorrelation is strictly positive are eligible: a zero or negative value
/// is no evidence of periodicity, and scaling a negative value by the prior
/// would invert the prior's preference.
///
/// Returns `0.0` when no estimate can be made: fewer than four frames, an empty
/// lag range, or no lag with positive (and non-NaN) autocorrelation.
fn estimate_bpm_autocorrelation(onset_envelope: &[f32], sample_rate: u32) -> f32 {
    const SLOWEST_BPM: f32 = 30.0;
    const PRIOR_CENTER_BPM: f32 = 120.0;
    const PRIOR_SPREAD_BPM: f32 = 40.0;

    let n = onset_envelope.len();
    if n < 4 {
        return 0.0;
    }

    // Envelope frames per second.
    let frame_rate = sample_rate as f32 / HOP_SIZE as f32;
    let min_lag = (frame_rate * 60.0 / MAX_BPM).ceil() as usize;
    // Capping at n / 2 keeps at least half the envelope overlapping at every lag
    // and guarantees max_lag < n.
    let max_lag = ((frame_rate * 60.0 / SLOWEST_BPM).floor() as usize).min(n / 2);
    if min_lag >= max_lag {
        return 0.0;
    }

    let mean = onset_envelope.iter().sum::<f32>() / n as f32;
    let centered: Vec<f32> = onset_envelope.iter().map(|&x| x - mean).collect();

    // Best (lag, score) seen so far; stays `None` when nothing correlates positively.
    let mut best: Option<(usize, f32)> = None;
    for lag in min_lag..=max_lag {
        let overlap = n - lag;
        let corr = centered[..overlap]
            .iter()
            .zip(&centered[lag..])
            .map(|(&a, &b)| a * b)
            .sum::<f32>()
            / overlap as f32;

        // Rejects zero, negative and NaN correlations alike.
        if corr > 0.0 {
            let bpm = frame_rate * 60.0 / lag as f32;
            let prior = (-0.5 * ((bpm - PRIOR_CENTER_BPM) / PRIOR_SPREAD_BPM).powi(2)).exp();
            let score = corr * prior;
            // Strict comparison: on ties the smaller lag (faster tempo) is kept.
            if best.map_or(true, |(_, top)| score > top) {
                best = Some((lag, score));
            }
        }
    }

    best.map_or(0.0, |(lag, _)| frame_rate * 60.0 / lag as f32)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Ten seconds of digital silence must produce the "no tempo" sentinel.
    #[test]
    fn test_tempo_silence() {
        let sample_rate = 22050u32;
        let silence = vec![0.0; sample_rate as usize * 10];

        let tempo = compute_tempo(&silence, sample_rate);

        assert_eq!(-1.0, tempo);
    }

    /// A click once per second (60 BPM) for thirty seconds must land in the
    /// lower half of the normalized range without hitting the sentinel.
    #[test]
    fn test_tempo_60bpm() {
        const CLICK_LEN: usize = 100;

        let sr = 22050u32;
        let samples_per_beat = sr as usize;
        let beats = 30;
        let mut signal = vec![0.0f32; samples_per_beat * beats];

        // Each one-second chunk starts with a short burst of full-scale samples.
        for second in signal.chunks_exact_mut(samples_per_beat) {
            second[..CLICK_LEN].fill(1.0);
        }

        let tempo = compute_tempo(&signal, sr);

        assert!(tempo > -0.8 && tempo < 0.0, "tempo = {tempo}");
    }
}