use crate::buffer::AudioBuffer;
/// Result of [`detect_tempo`]: an estimated tempo, how strongly the signal supports it, and
/// where the beats fall.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct BeatResult {
    /// Estimated tempo in beats per minute, or `0.0` when no tempo could be detected.
    pub bpm: f32,
    /// Strength of the tempo estimate, clamped to `[0.0, 1.0]`; `0.0` when no tempo could be
    /// detected.
    pub confidence: f32,
    /// Beat positions as sample offsets (onset-envelope frame index × hop size), in ascending
    /// order. NOTE(review): presumably per-channel frame offsets rather than interleaved
    /// sample indices — confirm against consumers.
    pub beat_positions: Vec<usize>,
}
/// Estimates the tempo of `buf` and the positions of its beats.
///
/// The tempo comes from the strongest autocorrelation peak of a spectral-flux onset envelope
/// (2048-sample window, 512-sample hop), searched only over lags that correspond to
/// `min_bpm..=max_bpm`. `confidence` is the autocorrelation coefficient of the mean-removed
/// envelope at the detected period, clamped to `[0, 1]`.
///
/// A result with `bpm == 0.0`, `confidence == 0.0` and no beats is returned (not an error)
/// when the envelope is too short, when the BPM range maps to an empty lag range, or when the
/// envelope shows no positive periodicity (e.g. silence).
///
/// # Errors
///
/// Returns `NadaError::Dsp` if `buf` holds less than one second of audio, and propagates any
/// error from building the onset envelope (STFT).
pub fn detect_tempo(buf: &AudioBuffer, min_bpm: f32, max_bpm: f32) -> crate::Result<BeatResult> {
    /// The "no tempo detected" result.
    fn no_tempo() -> BeatResult {
        BeatResult {
            bpm: 0.0,
            confidence: 0.0,
            beat_positions: Vec::new(),
        }
    }

    /// Autocorrelation coefficient of the mean-removed `env` at `lag`, clamped to `[0, 1]`.
    /// Removing the mean matters: the envelope is non-negative, so its raw autocorrelation is
    /// large at every lag and would overstate periodicity.
    fn periodicity_strength(env: &[f32], lag: usize) -> f32 {
        let mean = env.iter().sum::<f32>() / env.len() as f32;
        let centered: Vec<f32> = env.iter().map(|&v| v - mean).collect();
        let variance = autocorrelation(&centered, 0, 0)[0];
        if variance <= 0.0 {
            return 0.0;
        }
        (autocorrelation(&centered, lag, lag)[lag] / variance).clamp(0.0, 1.0)
    }

    let duration_s = buf.frames as f32 / buf.sample_rate as f32;
    if duration_s < 1.0 {
        return Err(crate::NadaError::Dsp(
            "detect_tempo requires at least 1 second of audio".into(),
        ));
    }

    let window_size = 2048;
    let hop_size = 512;
    let onset_env = onset_strength_envelope(buf, window_size, hop_size)?;
    if onset_env.len() < 4 {
        return Ok(no_tempo());
    }

    // One onset value per hop, so this is the envelope's sample rate.
    let onset_sr = buf.sample_rate as f32 / hop_size as f32;
    // Faster tempo => shorter period. Lag 0 is the envelope's energy, not a period, and would
    // yield an infinite BPM, so never search below lag 1 (reachable when max_bpm is very
    // large, non-positive or NaN).
    let min_lag = ((onset_sr * 60.0 / max_bpm) as usize).max(1);
    // Periods longer than half the envelope have too little overlap to be meaningful.
    let max_lag = ((onset_sr * 60.0 / min_bpm) as usize).min(onset_env.len() / 2);
    if min_lag >= max_lag {
        return Ok(no_tempo());
    }

    let autocorr = autocorrelation(&onset_env, min_lag, max_lag);
    let (best_lag, best_val, _) = find_peak(&autocorr, min_lag);
    if best_val <= 0.0 {
        return Ok(no_tempo());
    }

    let bpm = onset_sr * 60.0 / best_lag as f32;
    // Dividing the peak by the maximum over the same range is always 1.0; measure the peak's
    // strength relative to the envelope's own variance instead.
    let confidence = periodicity_strength(&onset_env, best_lag);
    let beat_positions = extract_beats(buf, &onset_env, hop_size, best_lag);
    Ok(BeatResult {
        bpm,
        confidence,
        beat_positions,
    })
}
/// Builds an onset-strength envelope with one value per STFT frame.
///
/// Frame 0 gets `0.0`. Every later frame gets the sum, over bins, of the positive increase
/// relative to the previous frame (half-wave-rectified spectral flux). The envelope is then
/// divided by its maximum so the peak is `1.0`; an all-zero envelope is left unscaled.
/// Returns an empty vector when the spectrogram has fewer than two frames.
///
/// # Errors
///
/// Propagates any error from `super::stft::stft`.
fn onset_strength_envelope(
    buf: &AudioBuffer,
    window_size: usize,
    hop_size: usize,
) -> crate::Result<Vec<f32>> {
    let spectrogram = super::stft::stft(buf, window_size, hop_size)?;
    let frames = &spectrogram.frames;
    let frame_count = frames.len();
    if frame_count < 2 {
        return Ok(Vec::new());
    }

    let mut envelope = vec![0.0f32];
    envelope.extend((1..frame_count).map(|t| {
        frames[t]
            .iter()
            .zip(frames[t - 1].iter())
            .map(|(&now, &before)| (now - before).max(0.0))
            .sum::<f32>()
    }));

    let peak = envelope.iter().copied().fold(0.0f32, f32::max);
    if peak > 0.0 {
        envelope.iter_mut().for_each(|v| *v /= peak);
    }
    Ok(envelope)
}
/// Autocorrelation of `signal` for every lag in `min_lag..=max_lag`.
///
/// Returns a vector of length `max_lag + 1` indexed by lag; entries outside the requested
/// range are `0.0`. Each computed entry is `sum(signal[i] * signal[i + lag])` divided by the
/// number of overlapping terms (`signal.len() - lag`). Lags with no overlap
/// (`lag >= signal.len()`) are left at `0.0` instead of producing NaN or panicking.
fn autocorrelation(signal: &[f32], min_lag: usize, max_lag: usize) -> Vec<f32> {
    let mut result = vec![0.0f32; max_lag + 1];
    for lag in min_lag..=max_lag {
        // `get` is None past the end; an empty tail (lag == len) has no overlap either.
        let shifted = match signal.get(lag..) {
            Some(tail) if !tail.is_empty() => tail,
            _ => continue,
        };
        // `zip` stops at the shorter side, i.e. after `signal.len() - lag` products.
        let sum: f32 = signal.iter().zip(shifted).map(|(&a, &b)| a * b).sum();
        result[lag] = sum / shifted.len() as f32;
    }
    result
}
/// Finds the lag with the largest autocorrelation value at or above `min_lag`.
///
/// Returns `(best_lag, best_val, max_val)`. Only values strictly greater than `0.0` can
/// become the peak, so when nothing is positive the result is `(min_lag, 0.0, 0.0)`. On ties
/// the smallest lag wins. `best_val` and `max_val` are always equal.
fn find_peak(autocorr: &[f32], min_lag: usize) -> (usize, f32, f32) {
    let (peak_lag, peak_val) = autocorr
        .iter()
        .copied()
        .enumerate()
        .skip(min_lag)
        .fold((min_lag, 0.0f32), |(lag_so_far, val_so_far), (lag, val)| {
            if val > val_so_far {
                (lag, val)
            } else {
                (lag_so_far, val_so_far)
            }
        });
    (peak_lag, peak_val, peak_val)
}
/// Places beats on a fixed grid of one beat every `period_frames` onset frames, starting at
/// frame 0.
///
/// Each grid point is snapped to the strongest strictly-positive onset within a quarter
/// period on either side; if none is positive, the grid point itself is used. Positions are
/// returned as `onset frame index * hop_size`, in strictly increasing index order, because
/// consecutive search windows never overlap. Positions at or beyond the end of the audio are
/// dropped.
fn extract_beats(
    buf: &AudioBuffer,
    onset_env: &[f32],
    hop_size: usize,
    period_frames: usize,
) -> Vec<usize> {
    if period_frames == 0 || onset_env.is_empty() {
        return Vec::new();
    }
    // `hop_size` is measured in frames (per-channel samples): the onset rate is derived as
    // `sample_rate / hop_size`. Positions are therefore frame offsets and must be bounded by
    // `buf.frames`, not `frames * channels`. The latter lets out-of-range positions through
    // for multi-channel audio.
    // NOTE(review): if consumers expect interleaved indices, scale by channels instead.
    let end_of_audio = buf.frames;
    let half_window = period_frames / 4;
    let mut beats = Vec::new();
    for pos in (0..onset_env.len()).step_by(period_frames) {
        let start = pos.saturating_sub(half_window);
        let end = (pos + half_window).min(onset_env.len());
        let mut best_idx = pos;
        let mut best_strength = 0.0f32;
        for (offset, &strength) in onset_env[start..end].iter().enumerate() {
            if strength > best_strength {
                best_strength = strength;
                best_idx = start + offset;
            }
        }
        let sample_pos = best_idx * hop_size;
        if sample_pos >= end_of_audio {
            // Indices only grow from here on, so every later position is out of range too.
            break;
        }
        beats.push(sample_pos);
    }
    beats
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a mono signal of `total_samples` zeros with `clicks` bursts spaced `sr / 2`
    /// samples apart. Each burst is `sr / 100` samples long (truncated at the end of the
    /// buffer), and its sample values are produced by `shape(offset_within_burst)`.
    fn half_second_clicks(
        sr: usize,
        total_samples: usize,
        clicks: usize,
        shape: impl Fn(usize) -> f32,
    ) -> Vec<f32> {
        let mut samples = vec![0.0f32; total_samples];
        let spacing = sr / 2;
        let burst = sr / 100;
        for onset in (0..clicks).map(|k| k * spacing) {
            let len = burst.min(total_samples - onset);
            for (offset, slot) in samples[onset..onset + len].iter_mut().enumerate() {
                *slot = shape(offset);
            }
        }
        samples
    }

    #[test]
    fn silence_no_beats() {
        let silent = AudioBuffer::silence(1, 44100 * 2, 44100);
        let outcome = detect_tempo(&silent, 60.0, 200.0).unwrap();
        assert!(outcome.beat_positions.is_empty() || outcome.confidence < 0.1);
    }

    #[test]
    fn click_track_120bpm() {
        let sr = 44100usize;
        // 1 kHz sine clicks, one every half second => 120 BPM.
        let samples = half_second_clicks(sr, sr * 4, 8, |i| {
            0.9 * (2.0 * std::f32::consts::PI * 1000.0 * i as f32 / sr as f32).sin()
        });
        let buf = AudioBuffer::from_interleaved(samples, 1, sr as u32).unwrap();
        let outcome = detect_tempo(&buf, 60.0, 200.0).unwrap();
        assert!(
            (outcome.bpm - 120.0).abs() < 15.0,
            "expected ~120 BPM, got {} (confidence: {})",
            outcome.bpm,
            outcome.confidence
        );
    }

    #[test]
    fn too_short_buffer() {
        // Half a second of audio is below the one-second minimum.
        let half_second = AudioBuffer::from_interleaved(vec![0.5; 22050], 1, 44100).unwrap();
        assert!(detect_tempo(&half_second, 60.0, 200.0).is_err());
    }

    #[test]
    fn beat_positions_ordered() {
        let sr = 44100usize;
        // Rectangular (DC) clicks.
        let samples = half_second_clicks(sr, sr * 3, 6, |_| 0.9);
        let buf = AudioBuffer::from_interleaved(samples, 1, sr as u32).unwrap();
        let outcome = detect_tempo(&buf, 60.0, 200.0).unwrap();
        for pair in outcome.beat_positions.windows(2) {
            assert!(pair[0] < pair[1], "beats not ordered: {} >= {}", pair[0], pair[1]);
        }
    }
}