/// A single beat with its position, confidence and sequence number.
#[derive(Debug, Clone)]
pub struct BeatEvent {
    /// Beat time in milliseconds.
    pub time_ms: u64,
    /// Confidence score attached to this beat.
    pub confidence: f32,
    /// Sequence number of the beat; [`BeatEvent::is_downbeat`] treats it as 1-based.
    pub beat_number: u32,
}

impl BeatEvent {
    /// Returns `true` when this beat opens a bar of `meter` beats.
    ///
    /// Beats are counted from 1, so beats `1`, `meter + 1`, `2 * meter + 1`, ...
    /// are downbeats. With `meter == 1` every beat (numbered 1 or higher) is a
    /// downbeat. A `meter` of 0 or a `beat_number` of 0 never yields a downbeat.
    #[must_use]
    pub fn is_downbeat(&self, meter: u8) -> bool {
        if meter == 0 {
            return false;
        }
        // Work on the 0-based index: the naive `n % meter == 1` test can never
        // succeed for `meter == 1`, because `n % 1` is always 0.
        self.beat_number
            .checked_sub(1)
            .is_some_and(|index| index % u32::from(meter) == 0)
    }
}
/// A tempo hypothesis: beats per minute, the beat period and a confidence score.
#[derive(Debug, Clone)]
pub struct TempoEstimate {
    /// Tempo in beats per minute.
    pub bpm: f32,
    /// Confidence score; [`TempoEstimate::is_reliable`] requires it to exceed 0.5.
    pub confidence: f32,
    /// Beat period in milliseconds.
    pub period_ms: f32,
}

impl TempoEstimate {
    /// Reports whether the estimate can be trusted.
    ///
    /// The estimate is reliable when `confidence` is strictly greater than 0.5
    /// and `bpm` lies in the inclusive range 30 to 300. `period_ms` is not
    /// consulted. NaN in either checked field makes the estimate unreliable.
    #[must_use]
    pub fn is_reliable(&self) -> bool {
        let tempo_in_range = (30.0..=300.0).contains(&self.bpm);
        let confident = self.confidence > 0.5;
        confident && tempo_in_range
    }
}
/// Estimates the tempo from onset times using an inter-onset-interval histogram.
///
/// Intervals between consecutive onsets are kept when they lie strictly between
/// 50 ms and 2000 ms, then counted in 10 ms wide bins. The centre of the fullest
/// bin becomes `period_ms`; when several bins tie, the one with the longest
/// period wins. `bpm` is `60_000 / period_ms`, and `confidence` is the share of
/// kept intervals that fell into the winning bin.
///
/// `onsets_ms` is expected to be in ascending order. A pair whose second onset
/// precedes the first is skipped rather than causing an arithmetic overflow.
///
/// When there are fewer than two onsets, or no interval survives the filter,
/// the result is a placeholder of 120 BPM / 500 ms with a confidence of 0.
#[must_use]
pub fn estimate_tempo_from_onsets(onsets_ms: &[u64]) -> TempoEstimate {
    const BIN_SIZE: f32 = 10.0;
    const MIN_IOI: f32 = 50.0;
    const MAX_IOI: f32 = 2000.0;
    const N_BINS: usize = ((MAX_IOI - MIN_IOI) / BIN_SIZE) as usize;
    // Returned whenever the input carries no usable interval.
    const FALLBACK: TempoEstimate = TempoEstimate {
        bpm: 120.0,
        confidence: 0.0,
        period_ms: 500.0,
    };

    let intervals = onsets_ms
        .windows(2)
        // `checked_sub` drops out-of-order pairs instead of underflowing `u64`.
        .filter_map(|pair| pair[1].checked_sub(pair[0]))
        .map(|delta| delta as f32)
        .filter(|&ioi| ioi > MIN_IOI && ioi < MAX_IOI);

    let mut histogram = [0u32; N_BINS];
    let mut total = 0usize;
    for ioi in intervals {
        total += 1;
        let bin = ((ioi - MIN_IOI) / BIN_SIZE) as usize;
        if let Some(count) = histogram.get_mut(bin) {
            *count += 1;
        }
    }
    if total == 0 {
        return FALLBACK;
    }

    // `max_by_key` yields the last maximum, so ties favour the longer period.
    let (peak_bin, peak_count) = histogram
        .iter()
        .copied()
        .enumerate()
        .max_by_key(|&(_, count)| count)
        .unwrap_or((0, 0));

    // Report the centre of the winning bin.
    let period_ms = MIN_IOI + (peak_bin as f32 + 0.5) * BIN_SIZE;
    TempoEstimate {
        bpm: 60_000.0 / period_ms,
        confidence: (peak_count as f32 / total as f32).min(1.0),
        period_ms,
    }
}
/// Projects an evenly spaced beat grid from a fixed tempo estimate.
#[derive(Debug, Clone)]
pub struct BeatTracker {
    /// Tempo whose period and confidence are applied to every projected beat.
    pub tempo_estimate: TempoEstimate,
}

impl BeatTracker {
    /// Creates a tracker that uses `tempo_estimate` for all projections.
    #[must_use]
    pub fn new(tempo_estimate: TempoEstimate) -> Self {
        Self { tempo_estimate }
    }

    /// Lays out beats one period apart, starting at the first onset.
    ///
    /// Beats are emitted at `first + k * period_ms` for `k = 0, 1, 2, ...` while
    /// the position does not exceed `last + period_ms`, where `first` and `last`
    /// are the first and last elements of `onsets_ms`. Beats are numbered from 1
    /// and each carries the tempo estimate's confidence. Positions are truncated
    /// to whole milliseconds.
    ///
    /// Returns an empty vector when `onsets_ms` is empty or when the period is
    /// zero, negative, NaN or infinite.
    #[must_use]
    pub fn track(&self, onsets_ms: &[u64]) -> Vec<BeatEvent> {
        let (Some(&first), Some(&last)) = (onsets_ms.first(), onsets_ms.last()) else {
            return Vec::new();
        };
        let period_ms = f64::from(self.tempo_estimate.period_ms);
        if !period_ms.is_finite() || period_ms <= 0.0 {
            return Vec::new();
        }

        // Work in f64 and derive each position from the beat index rather than
        // accumulating: with f32, adding the period to a large timestamp (for
        // example epoch milliseconds) rounds back to the same value, so a
        // running sum would never advance.
        let start_ms = first as f64;
        let end_ms = last as f64 + period_ms;
        let confidence = self.tempo_estimate.confidence;

        (1..=u32::MAX)
            .map(|beat_number| {
                let offset = f64::from(beat_number - 1) * period_ms;
                (beat_number, start_ms + offset)
            })
            .take_while(|&(_, position)| position <= end_ms)
            .map(|(beat_number, position)| BeatEvent {
                time_ms: position as u64,
                confidence,
                beat_number,
            })
            .collect()
    }
}
/// Namespace for meter estimation; carries no state.
pub struct MeterAnalyzer;

impl MeterAnalyzer {
    /// Guesses whether `beats` are grouped in threes or in fours.
    ///
    /// The mean confidence of every third beat (indices 0, 3, 6, ...) is compared
    /// with the mean confidence of every fourth beat (indices 0, 4, 8, ...).
    /// Returns 3 only when the first mean is strictly greater; otherwise 4.
    /// Fewer than six beats always yields 4.
    #[must_use]
    pub fn estimate_meter(beats: &[BeatEvent]) -> u8 {
        if beats.len() < 6 {
            return 4;
        }

        // Mean confidence over the beats at indices 0, stride, 2 * stride, ...
        let mean_confidence = |stride: usize| -> f32 {
            let total: f32 = beats
                .iter()
                .step_by(stride)
                .map(|beat| beat.confidence)
                .sum();
            // `div_ceil` equals the number of sampled beats; never 0 here
            // because at least six beats are present.
            let sampled = beats.len().div_ceil(stride) as f32;
            total / sampled
        };

        let triple = mean_confidence(3);
        let quadruple = mean_confidence(4);
        if triple > quadruple {
            3
        } else {
            4
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`BeatEvent`] fixture.
    fn beat(time_ms: u64, confidence: f32, beat_number: u32) -> BeatEvent {
        BeatEvent {
            time_ms,
            confidence,
            beat_number,
        }
    }

    /// Shorthand for building a [`TempoEstimate`] fixture.
    fn tempo(bpm: f32, confidence: f32, period_ms: f32) -> TempoEstimate {
        TempoEstimate {
            bpm,
            confidence,
            period_ms,
        }
    }

    /// Beats numbered `1..=count`, 500 ms apart, all with confidence 0.7.
    fn uniform_beats(count: u32) -> Vec<BeatEvent> {
        (1..=count)
            .map(|n| beat(u64::from(n) * 500, 0.7, n))
            .collect()
    }

    #[test]
    fn test_beat_event_is_downbeat_4_4() {
        let first = beat(0, 1.0, 1);
        let second = beat(500, 0.8, 2);
        let fifth = beat(2000, 1.0, 5);

        assert!(first.is_downbeat(4));
        assert!(!second.is_downbeat(4));
        assert!(fifth.is_downbeat(4));
    }

    #[test]
    fn test_beat_event_is_downbeat_3_4() {
        let first = beat(0, 1.0, 1);
        let fourth = beat(1500, 1.0, 4);
        let third = beat(1000, 0.5, 3);

        assert!(first.is_downbeat(3));
        assert!(fourth.is_downbeat(3));
        assert!(!third.is_downbeat(3));
    }

    #[test]
    fn test_beat_event_zero_meter() {
        assert!(!beat(0, 1.0, 1).is_downbeat(0));
    }

    #[test]
    fn test_tempo_estimate_is_reliable() {
        // Plausible tempo with high confidence.
        assert!(tempo(120.0, 0.8, 500.0).is_reliable());
        // Confidence too low.
        assert!(!tempo(120.0, 0.4, 500.0).is_reliable());
        // Tempo below the accepted range.
        assert!(!tempo(10.0, 0.9, 6000.0).is_reliable());
        // Tempo above the accepted range.
        assert!(!tempo(400.0, 0.9, 150.0).is_reliable());
    }

    #[test]
    fn test_tempo_estimate_is_reliable_boundary() {
        // A confidence of exactly 0.5 does not pass the strict comparison.
        assert!(!tempo(120.0, 0.5, 500.0).is_reliable());
    }

    #[test]
    fn test_estimate_tempo_insufficient_onsets() {
        let estimate = estimate_tempo_from_onsets(&[100]);
        assert_eq!(estimate.confidence, 0.0);
    }

    #[test]
    fn test_estimate_tempo_regular_onsets() {
        let onsets: Vec<u64> = (0..10).map(|i| i * 500).collect();
        let estimate = estimate_tempo_from_onsets(&onsets);

        assert!(estimate.bpm > 100.0 && estimate.bpm < 140.0);
        assert!(estimate.confidence > 0.0);
    }

    #[test]
    fn test_estimate_tempo_empty() {
        let estimate = estimate_tempo_from_onsets(&[]);
        assert_eq!(estimate.confidence, 0.0);
    }

    #[test]
    fn test_beat_tracker_empty_onsets() {
        let tracker = BeatTracker::new(tempo(120.0, 0.9, 500.0));
        assert!(tracker.track(&[]).is_empty());
    }

    #[test]
    fn test_beat_tracker_projects_beats() {
        let tracker = BeatTracker::new(tempo(120.0, 0.9, 500.0));
        let beats = tracker.track(&[0u64, 500, 1000, 1500, 2000]);

        assert!(!beats.is_empty());
        assert_eq!(beats[0].time_ms, 0);
        assert!((490..=510).contains(&beats[1].time_ms));
    }

    #[test]
    fn test_beat_tracker_beat_numbers() {
        let tracker = BeatTracker::new(tempo(120.0, 0.8, 500.0));
        let beats = tracker.track(&[0u64, 500, 1000, 1500]);

        // Numbering starts at 1 and increases by one per beat.
        for (projected, expected) in beats.iter().zip(1u32..) {
            assert_eq!(projected.beat_number, expected);
        }
    }

    #[test]
    fn test_meter_analyzer_defaults_to_4() {
        let meter = MeterAnalyzer::estimate_meter(&uniform_beats(4));
        assert_eq!(meter, 4);
    }

    #[test]
    fn test_meter_analyzer_returns_3_or_4() {
        let meter = MeterAnalyzer::estimate_meter(&uniform_beats(12));
        assert!(matches!(meter, 3 | 4));
    }
}