use crate::{MirError, MirResult};
/// Tempo estimator that turns beat intervals into a beats-per-minute value
/// restricted to a configurable `[min_bpm, max_bpm]` range.
// `sample_rate` is stored by the constructor but is not read by any non-test
// code visible in this file, hence the lint exemption.
#[allow(dead_code)]
#[derive(Debug, Clone, PartialEq)]
pub struct BpmEstimator {
    /// Audio sample rate supplied at construction (presumably in Hz — confirm
    /// against callers).
    sample_rate: f32,
    /// Lowest tempo, in BPM, accepted as a valid estimate.
    min_bpm: f32,
    /// Highest tempo, in BPM, accepted as a valid estimate.
    max_bpm: f32,
}
impl BpmEstimator {
    /// Creates an estimator that only reports tempi inside
    /// `[min_bpm, max_bpm]`.
    ///
    /// The arguments are stored as given; no validation is performed.
    #[must_use]
    pub fn new(sample_rate: f32, min_bpm: f32, max_bpm: f32) -> Self {
        Self {
            sample_rate,
            min_bpm,
            max_bpm,
        }
    }

    /// Estimates the dominant tempo from a list of beat intervals.
    ///
    /// Every strictly positive interval is converted with `60.0 / interval`
    /// (i.e. intervals are treated as seconds per beat). Tempi outside
    /// `[min_bpm, max_bpm]` are discarded and the remaining values are
    /// histogrammed to find the most common tempo.
    ///
    /// Returns `(bpm, confidence)` where `confidence` is the fraction of the
    /// surviving tempo values that fell into the winning histogram bin.
    ///
    /// # Errors
    ///
    /// * [`MirError::InsufficientData`] if `intervals` is empty.
    /// * [`MirError::AnalysisFailed`] if no interval yields a tempo inside
    ///   the configured range.
    pub fn estimate_from_intervals(&self, intervals: &[f32]) -> MirResult<(f32, f32)> {
        if intervals.is_empty() {
            return Err(MirError::InsufficientData(
                "No intervals provided".to_string(),
            ));
        }

        let bpms: Vec<f32> = intervals
            .iter()
            // `> 0.0` also rejects NaN, so the division below is well defined.
            .filter(|&&interval| interval > 0.0)
            .map(|&interval| 60.0 / interval)
            .filter(|&bpm| bpm >= self.min_bpm && bpm <= self.max_bpm)
            .collect();

        if bpms.is_empty() {
            return Err(MirError::AnalysisFailed(
                "No valid BPM estimates".to_string(),
            ));
        }

        Ok(self.find_dominant_bpm(&bpms))
    }

    /// Histograms `bpms` into 2-BPM-wide bins starting at `min_bpm` and
    /// returns `(mean tempo of the fullest bin, fraction of values in it)`.
    ///
    /// When several bins are equally full, the one covering the highest
    /// tempo wins (`max_by_key` keeps the last maximum).
    // The float -> usize casts saturate (negative -> 0), and the result is
    // clamped to a valid bin index; the usize -> f32 casts only lose precision
    // for counts far beyond any realistic number of intervals.
    #[allow(
        clippy::cast_precision_loss,
        clippy::cast_possible_truncation,
        clippy::cast_sign_loss
    )]
    fn find_dominant_bpm(&self, bpms: &[f32]) -> (f32, f32) {
        const BIN_SIZE: f32 = 2.0;

        // Always keep at least one bin so that `min_bpm == max_bpm` still has
        // somewhere to count its (only possible) tempo.
        let num_bins = (((self.max_bpm - self.min_bpm) / BIN_SIZE).ceil() as usize).max(1);
        let last_bin = num_bins - 1;

        let mut counts = vec![0_usize; num_bins];
        let mut sums = vec![0.0_f32; num_bins];

        for &bpm in bpms {
            // A tempo exactly on `max_bpm` would index one past the end when
            // the range is a multiple of the bin width; fold it into the last
            // bin instead of silently dropping it.
            let bin = (((bpm - self.min_bpm) / BIN_SIZE).floor() as usize).min(last_bin);
            counts[bin] += 1;
            sums[bin] += bpm;
        }

        let (max_bin, &max_count) = counts
            .iter()
            .enumerate()
            .max_by_key(|&(_, &count)| count)
            .unwrap_or((0, &0));

        // Only reachable when `bpms` is empty.
        if max_count == 0 {
            return (self.min_bpm, 0.0);
        }

        let dominant_bpm = sums[max_bin] / max_count as f32;
        let confidence = max_count as f32 / bpms.len() as f32;
        (dominant_bpm, confidence)
    }

    /// Scores how well `onset_times` line up with a beat grid at
    /// `initial_bpm`.
    ///
    /// One hundred evenly spaced phase offsets across a single beat period
    /// are tried. For each, every onset contributes a Gaussian weight based
    /// on its distance to the nearest grid beat, and the best total is kept.
    ///
    /// Returns `(initial_bpm, confidence)`: the tempo is passed through
    /// unchanged, and `confidence` is the best total divided by the number of
    /// onsets, clamped to `[0.0, 1.0]`.
    ///
    /// A zero or non-finite `initial_bpm` makes every score NaN, which yields
    /// a confidence of `0.0` rather than an error.
    ///
    /// # Errors
    ///
    /// Returns [`MirError::InsufficientData`] if fewer than two onset times
    /// are supplied.
    // `onset_times.len() as f32` only loses precision for absurdly long inputs.
    #[allow(clippy::cast_precision_loss)]
    pub fn refine_with_phase(
        &self,
        initial_bpm: f32,
        onset_times: &[f32],
    ) -> MirResult<(f32, f32)> {
        /// Number of phase offsets tried across one beat period.
        const PHASE_STEPS: u16 = 100;
        /// Standard deviation of the Gaussian alignment weight, in the same
        /// unit as `onset_times` (presumably seconds — confirm with callers).
        const TOLERANCE: f32 = 0.070;

        if onset_times.len() < 2 {
            return Err(MirError::InsufficientData(
                "Need at least 2 onset times".to_string(),
            ));
        }

        let beat_period = 60.0 / initial_bpm;
        // Loop-invariant denominator of the Gaussian exponent.
        let two_sigma_sq = 2.0 * TOLERANCE.powi(2);

        let mut best_score = 0.0_f32;
        for step in 0..PHASE_STEPS {
            let phase_offset = f32::from(step) * beat_period / f32::from(PHASE_STEPS);

            let mut score = 0.0_f32;
            for &onset_time in onset_times {
                // Snap the onset to the closest beat of the shifted grid.
                let beat_number = ((onset_time - phase_offset) / beat_period).round();
                let beat_time = phase_offset + beat_number * beat_period;
                let error = (onset_time - beat_time).abs();
                score += (-error.powi(2) / two_sigma_sq).exp();
            }

            // NaN scores compare false here and are therefore ignored.
            if score > best_score {
                best_score = score;
            }
        }

        let confidence = (best_score / onset_times.len() as f32).clamp(0.0, 1.0);
        Ok((initial_bpm, confidence))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds the estimator configuration shared by every test.
    fn make_estimator() -> BpmEstimator {
        BpmEstimator::new(44100.0, 60.0, 200.0)
    }

    #[test]
    fn test_bpm_estimator_creation() {
        let estimator = make_estimator();
        // Tolerance compare instead of `assert_eq!` on floats.
        assert!((estimator.sample_rate - 44100.0).abs() < f32::EPSILON);
        assert!((estimator.min_bpm - 60.0).abs() < f32::EPSILON);
        assert!((estimator.max_bpm - 200.0).abs() < f32::EPSILON);
    }

    #[test]
    fn test_estimate_from_intervals() {
        let estimator = make_estimator();
        // Four half-second intervals correspond to 120 BPM.
        let intervals = [0.5, 0.5, 0.5, 0.5];
        let (bpm, confidence) = estimator
            .estimate_from_intervals(&intervals)
            .expect("should succeed in test");
        assert!((bpm - 120.0).abs() < 5.0);
        assert!(confidence > 0.5);
    }

    #[test]
    fn test_estimate_empty_intervals() {
        let estimator = make_estimator();
        assert!(estimator.estimate_from_intervals(&[]).is_err());
    }

    #[test]
    fn test_estimate_rejects_unusable_intervals() {
        let estimator = make_estimator();
        // Zero and negative intervals are skipped; 10 s maps to 6 BPM, which
        // is below the configured minimum.
        assert!(estimator
            .estimate_from_intervals(&[0.0, -1.0, 10.0])
            .is_err());
    }

    #[test]
    fn test_refine_with_phase_requires_two_onsets() {
        let estimator = make_estimator();
        assert!(estimator.refine_with_phase(120.0, &[]).is_err());
        assert!(estimator.refine_with_phase(120.0, &[0.5]).is_err());
    }

    #[test]
    fn test_refine_with_phase_aligned_onsets() {
        let estimator = make_estimator();
        // Onsets exactly on a 120 BPM grid with zero phase offset.
        let onsets = [0.0, 0.5, 1.0, 1.5];
        let (bpm, confidence) = estimator
            .refine_with_phase(120.0, &onsets)
            .expect("should succeed in test");
        assert!((bpm - 120.0).abs() < f32::EPSILON);
        assert!(confidence > 0.99);
        assert!(confidence <= 1.0);
    }
}