#![allow(dead_code)]
#![allow(clippy::cast_precision_loss)]
/// One candidate tempo, expressed both as beats per minute and as a beat period.
///
/// Produced by [`BeatTracker::detect_tempo`], but can also be built directly via
/// [`TempoHypothesis::new`].
#[derive(Debug, Clone)]
pub struct TempoHypothesis {
    /// Tempo in beats per minute.
    pub bpm: f32,
    /// Caller-supplied confidence. `BeatTracker::detect_tempo` sets it to the share of
    /// in-range inter-onset intervals that voted for this candidate.
    pub confidence: f32,
    /// Beat period in seconds (`60 / bpm`). It is computed once by `new` and is not
    /// kept in sync if `bpm` is modified afterwards.
    pub period_s: f32,
}

impl TempoHypothesis {
    /// Builds a hypothesis and derives `period_s = 60 / bpm`.
    ///
    /// No validation is performed: `bpm == 0.0` yields an infinite period.
    #[must_use]
    pub fn new(bpm: f32, confidence: f32) -> Self {
        Self {
            period_s: 60.0 / bpm,
            bpm,
            confidence,
        }
    }

    /// Ranking score of this hypothesis; currently just its `confidence`.
    #[must_use]
    pub fn score(&self) -> f32 {
        self.confidence
    }

    /// Beat period in milliseconds.
    #[must_use]
    pub fn period_ms(&self) -> f32 {
        1000.0 * self.period_s
    }
}
/// A time span in seconds, e.g. the gap between two consecutive beats as returned by
/// [`BeatTracker::beat_intervals`].
#[derive(Debug, Clone, Copy)]
pub struct TempoInterval {
    /// Start of the span, in seconds.
    pub start_s: f32,
    /// End of the span, in seconds.
    pub end_s: f32,
}

impl TempoInterval {
    /// Creates a span from `start_s` to `end_s` (no ordering check is made).
    #[must_use]
    pub fn new(start_s: f32, end_s: f32) -> Self {
        Self { start_s, end_s }
    }

    /// Tempo implied by treating this span as one beat period.
    ///
    /// Returns `0.0` unless the duration is strictly positive (this covers zero,
    /// negative and NaN durations).
    #[must_use]
    pub fn bpm(&self) -> f32 {
        match self.duration_s() {
            period if period > 0.0 => 60.0 / period,
            _ => 0.0,
        }
    }

    /// Signed length of the span, `end_s - start_s`, in seconds.
    #[must_use]
    pub fn duration_s(&self) -> f32 {
        self.end_s - self.start_s
    }
}
/// Onset-based tempo estimator.
///
/// Onsets are collected with [`BeatTracker::add_onset`] / [`BeatTracker::add_onset_time`].
/// [`BeatTracker::detect_tempo`] then votes inter-onset intervals (IOIs) into a
/// fixed-width histogram restricted to the configured BPM range. It reports the
/// most-voted periods as [`TempoHypothesis`] values and builds a beat grid from the best one.
#[derive(Debug, Clone)]
pub struct BeatTracker {
    /// Sample rate in Hz; used by `add_onset` to convert frames to seconds.
    sample_rate: f32,
    /// Samples per analysis frame; used by `add_onset` to convert frames to seconds.
    hop_size: usize,
    /// Slowest tempo considered, in BPM (sets the longest accepted IOI).
    min_bpm: f32,
    /// Fastest tempo considered, in BPM (sets the shortest accepted IOI).
    max_bpm: f32,
    /// Onset times in seconds, in insertion order (not necessarily sorted).
    onset_times: Vec<f32>,
    /// Beat grid from the most recent `detect_tempo` call; empty if that call found nothing.
    beat_positions: Vec<f32>,
}

impl BeatTracker {
    /// Width of one IOI histogram bin, in seconds.
    const IOI_BIN_WIDTH_S: f32 = 0.010;
    /// Number of following onsets (in time order) each onset is paired with.
    const IOI_LOOKAHEAD: usize = 7;

    /// Creates an empty tracker.
    ///
    /// Only [`BeatTracker::add_onset`] uses `sample_rate` (Hz) and `hop_size`
    /// (samples per frame). `min_bpm..=max_bpm` bounds the tempi that
    /// [`BeatTracker::detect_tempo`] may report. If that range is non-finite,
    /// non-positive or inverted, detection returns no hypotheses.
    #[must_use]
    pub fn new(sample_rate: f32, hop_size: usize, min_bpm: f32, max_bpm: f32) -> Self {
        Self {
            sample_rate,
            hop_size,
            min_bpm,
            max_bpm,
            onset_times: Vec::new(),
            beat_positions: Vec::new(),
        }
    }

    /// Records an onset at analysis frame `frame_index`, i.e. at
    /// `frame_index * hop_size / sample_rate` seconds.
    // The casts are intentional. The arithmetic runs in f64 because
    // `frame_index * hop_size` can overflow `usize` on 32-bit targets, and a raw sample
    // count cast to f32 loses precision past 2^24 samples.
    #[allow(clippy::cast_precision_loss, clippy::cast_possible_truncation)]
    pub fn add_onset(&mut self, frame_index: usize) {
        let samples = frame_index as f64 * self.hop_size as f64;
        self.add_onset_time((samples / f64::from(self.sample_rate)) as f32);
    }

    /// Records an onset at `time_s` seconds. Onsets may be added in any order.
    pub fn add_onset_time(&mut self, time_s: f32) {
        self.onset_times.push(time_s);
    }

    /// Estimates the tempo and returns up to `n_hypotheses` candidates, best first.
    ///
    /// # How candidates are formed
    ///
    /// Onsets are first sorted by time, and non-finite values are ignored. Each onset is
    /// then paired with up to the next seven onsets. Every interval inside the
    /// `60/max_bpm ..= 60/min_bpm` second range votes for one 10 ms histogram bin.
    ///
    /// Candidates are the non-empty bins, ordered by vote count; ties go to the shorter
    /// period. A candidate's `bpm` comes from its bin centre, and its `confidence` is
    /// the bin's share of all in-range votes, so confidences sum to at most 1.
    ///
    /// # Side effect on the beat grid
    ///
    /// The beat grid ([`BeatTracker::beat_positions`]) is rebuilt from the top returned
    /// candidate, or cleared when no candidate is returned (including when
    /// `n_hypotheses == 0`).
    ///
    /// # When the result is empty
    ///
    /// - fewer than two usable onsets exist,
    /// - no interval falls in range, or
    /// - the BPM range is invalid.
    #[must_use]
    pub fn detect_tempo(&mut self, n_hypotheses: usize) -> Vec<TempoHypothesis> {
        // A grid from an earlier call must never outlive a detection that finds nothing.
        self.beat_positions.clear();

        let range_is_valid = self.min_bpm.is_finite()
            && self.max_bpm.is_finite()
            && self.min_bpm > 0.0
            && self.min_bpm <= self.max_bpm;
        if !range_is_valid {
            return Vec::new();
        }

        // IOIs are only meaningful between time-ordered onsets, but insertion order is arbitrary.
        let times = self.sorted_finite_onsets();
        if times.len() < 2 {
            return Vec::new();
        }

        let min_period = 60.0 / self.max_bpm;
        let max_period = 60.0 / self.min_bpm;
        let bin_width = Self::IOI_BIN_WIDTH_S;
        let n_bins = ((max_period - min_period) / bin_width).ceil() as usize + 1;

        let mut histogram = vec![0_u32; n_bins];
        for (i, &earlier) in times.iter().enumerate() {
            let window_end = times.len().min(i + 1 + Self::IOI_LOOKAHEAD);
            for &later in &times[i + 1..window_end] {
                let ioi = later - earlier;
                if (min_period..=max_period).contains(&ioi) {
                    // `ioi >= min_period`, so the truncating cast is a plain floor.
                    let bin = ((ioi - min_period) / bin_width) as usize;
                    if let Some(votes) = histogram.get_mut(bin) {
                        *votes += 1;
                    }
                }
            }
        }

        let total: u32 = histogram.iter().sum();
        if total == 0 {
            return Vec::new();
        }

        // Non-empty bins, most votes first. The sort is stable, so equal counts keep
        // ascending bin order (shorter period / faster tempo first).
        let mut ranked: Vec<(usize, u32)> = histogram
            .iter()
            .copied()
            .enumerate()
            .filter(|&(_, votes)| votes > 0)
            .collect();
        ranked.sort_by(|a, b| b.1.cmp(&a.1));

        let hypotheses: Vec<TempoHypothesis> = ranked
            .iter()
            .take(n_hypotheses)
            .map(|&(bin, votes)| {
                // Report the bin centre rather than its lower edge.
                let period = min_period + (bin as f32 + 0.5) * bin_width;
                TempoHypothesis::new(60.0 / period, votes as f32 / total as f32)
            })
            .collect();

        if let (Some(top), Some(&first), Some(&last)) =
            (hypotheses.first(), times.first(), times.last())
        {
            self.beat_positions = Self::generate_beat_grid(first, last, top.period_s);
        }
        hypotheses
    }

    /// Beat grid produced by the most recent [`BeatTracker::detect_tempo`] call, in seconds.
    #[must_use]
    pub fn beat_positions(&self) -> &[f32] {
        &self.beat_positions
    }

    /// Recorded onset times in seconds, in insertion order.
    #[must_use]
    pub fn onset_times(&self) -> &[f32] {
        &self.onset_times
    }

    /// Intervals between consecutive beats of the current beat grid.
    #[must_use]
    pub fn beat_intervals(&self) -> Vec<TempoInterval> {
        self.beat_positions
            .windows(2)
            .map(|w| TempoInterval::new(w[0], w[1]))
            .collect()
    }

    /// Finite onset times in ascending order. The stored `onset_times` keeps insertion order.
    fn sorted_finite_onsets(&self) -> Vec<f32> {
        let mut times: Vec<f32> = self
            .onset_times
            .iter()
            .copied()
            .filter(|t| t.is_finite())
            .collect();
        times.sort_by(f32::total_cmp);
        times
    }

    /// Evenly spaced beats starting at `start`, every `period_s` seconds, up to half a
    /// period past `end`.
    ///
    /// Positions are computed as `start + k * period_s` rather than by repeated addition,
    /// so rounding error does not accumulate over long inputs. A non-finite or
    /// non-positive period yields an empty grid.
    fn generate_beat_grid(start: f32, end: f32, period_s: f32) -> Vec<f32> {
        if !period_s.is_finite() || period_s <= 0.0 {
            return Vec::new();
        }
        let limit = end + period_s * 0.5;
        (0_u32..)
            .map(|k| start + k as f32 * period_s)
            .take_while(|&t| t <= limit)
            .collect()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Tracker with the configuration shared by every test: 44.1 kHz, hop 512, 60–200 BPM.
    fn fresh_tracker() -> BeatTracker {
        BeatTracker::new(44100.0, 512, 60.0, 200.0)
    }

    /// Tracker pre-loaded with onsets at `0.0, 0.5, …, 0.5 * last` seconds (a 120 BPM pulse).
    fn tracker_with_pulse(last: usize) -> BeatTracker {
        let mut tracker = fresh_tracker();
        for step in 0..=last {
            tracker.add_onset_time(step as f32 * 0.5);
        }
        tracker
    }

    #[test]
    fn test_hypothesis_bpm_roundtrip() {
        let hyp = TempoHypothesis::new(120.0, 0.9);
        assert!((hyp.bpm - 120.0).abs() < 1e-4);
    }

    #[test]
    fn test_hypothesis_period_correct() {
        let hyp = TempoHypothesis::new(60.0, 1.0);
        assert!((hyp.period_s - 1.0).abs() < 1e-5);
    }

    #[test]
    fn test_hypothesis_period_ms() {
        let millis = TempoHypothesis::new(120.0, 1.0).period_ms();
        assert!((millis - 500.0).abs() < 0.1);
    }

    #[test]
    fn test_hypothesis_score_equals_confidence() {
        let score = TempoHypothesis::new(100.0, 0.7).score();
        assert!((score - 0.7).abs() < 1e-6);
    }

    #[test]
    fn test_interval_bpm_120() {
        let bpm = TempoInterval::new(0.0, 0.5).bpm();
        assert!((bpm - 120.0).abs() < 0.1);
    }

    #[test]
    fn test_interval_duration() {
        let span = TempoInterval::new(1.0, 1.75).duration_s();
        assert!((span - 0.75).abs() < 1e-6);
    }

    #[test]
    fn test_interval_zero_duration() {
        assert_eq!(TempoInterval::new(1.0, 1.0).bpm(), 0.0);
    }

    #[test]
    fn test_tracker_no_onsets_returns_empty() {
        let mut tracker = fresh_tracker();
        assert!(tracker.detect_tempo(3).is_empty());
    }

    #[test]
    fn test_tracker_one_onset_returns_empty() {
        let mut tracker = fresh_tracker();
        tracker.add_onset_time(0.0);
        assert!(tracker.detect_tempo(3).is_empty());
    }

    #[test]
    fn test_tracker_detects_120bpm() {
        let mut tracker = tracker_with_pulse(8);
        let hyps = tracker.detect_tempo(3);
        assert!(!hyps.is_empty());
        let top_bpm = hyps[0].bpm;
        assert!(
            top_bpm > 110.0 && top_bpm < 130.0,
            "Expected ~120 BPM, got {top_bpm}"
        );
    }

    #[test]
    fn test_beat_positions_populated_after_detect() {
        let mut tracker = tracker_with_pulse(8);
        let _ = tracker.detect_tempo(1);
        assert!(!tracker.beat_positions().is_empty());
    }

    #[test]
    fn test_add_onset_by_frame() {
        let mut tracker = fresh_tracker();
        for frame in [0, 3969] {
            tracker.add_onset(frame);
        }
        assert_eq!(tracker.onset_times().len(), 2);
    }

    #[test]
    fn test_onset_times_sorted_by_insertion() {
        let mut tracker = fresh_tracker();
        for t in [0.0, 0.5, 1.0] {
            tracker.add_onset_time(t);
        }
        let times = tracker.onset_times();
        assert_eq!(times.len(), 3);
        assert!(times.windows(2).all(|pair| pair[0] < pair[1]));
    }

    #[test]
    fn test_beat_intervals_count() {
        let mut tracker = tracker_with_pulse(10);
        let _ = tracker.detect_tempo(1);
        assert!(!tracker.beat_intervals().is_empty());
    }

    #[test]
    fn test_hypothesis_confidence_sum_not_exceeds_one() {
        let mut tracker = tracker_with_pulse(16);
        let hyps = tracker.detect_tempo(5);
        let total: f32 = hyps.iter().map(|h| h.confidence).sum();
        assert!(total <= 1.0 + 1e-4);
    }
}