#![allow(dead_code)]
use std::f32::consts::PI;
/// A detected onset: a position in time paired with a normalised strength.
#[derive(Debug, Clone, PartialEq)]
pub struct Onset {
    /// Position of the onset, in seconds.
    pub time_s: f32,
    /// Strength of the onset; [`Onset::new`] limits it to `[0.0, 1.0]`.
    pub strength: f32,
}

impl Onset {
    /// Creates an onset at `time_s`, clamping `strength` into `[0.0, 1.0]`.
    ///
    /// `time_s` is stored as given.
    #[must_use]
    pub fn new(time_s: f32, strength: f32) -> Self {
        let strength = strength.clamp(0.0, 1.0);
        Self { time_s, strength }
    }
}
/// A tempo hypothesis with its beat period and a confidence score.
#[derive(Debug, Clone, PartialEq)]
pub struct TempoEstimate {
    /// Tempo in beats per minute.
    pub bpm: f32,
    /// Confidence of the estimate; [`TempoEstimate::from_bpm`] limits it to `[0.0, 1.0]`.
    pub confidence: f32,
    /// Duration of one beat in seconds (`60 / bpm`), or `0.0` when `bpm` is not positive.
    pub period_s: f32,
}

impl TempoEstimate {
    /// Numerator used to turn beats-per-minute into a beat period in seconds.
    const SECONDS_PER_MINUTE: f32 = 60.0;

    /// Builds an estimate from a tempo in BPM.
    ///
    /// The beat period is derived as `60 / bpm`. Any `bpm` that is not strictly
    /// greater than zero (including NaN) gets a period of `0.0`. `confidence` is
    /// clamped into `[0.0, 1.0]`; `bpm` itself is stored unchanged.
    #[must_use]
    pub fn from_bpm(bpm: f32, confidence: f32) -> Self {
        let has_beat = bpm > 0.0;
        let period_s = if has_beat {
            Self::SECONDS_PER_MINUTE / bpm
        } else {
            0.0
        };
        let confidence = confidence.clamp(0.0, 1.0);
        Self {
            bpm,
            confidence,
            period_s,
        }
    }

    /// Returns `true` when the tempo lies in the inclusive range 40–300 BPM.
    #[must_use]
    pub fn is_plausible(&self) -> bool {
        (40.0..=300.0).contains(&self.bpm)
    }
}
/// A regular grid of beats: a tempo, the time of the first beat, and a beat count.
#[derive(Debug, Clone)]
pub struct BeatGrid {
    /// Tempo whose `period_s` is the spacing between consecutive beats.
    pub tempo: TempoEstimate,
    /// Time of beat 0, in seconds.
    pub phase_s: f32,
    /// Number of beats produced by [`BeatGrid::all_beat_times`].
    pub beat_count: u32,
}

impl BeatGrid {
    /// Assembles a grid from its parts; the values are stored as given.
    #[must_use]
    pub fn new(tempo: TempoEstimate, phase_s: f32, beat_count: u32) -> Self {
        Self {
            tempo,
            phase_s,
            beat_count,
        }
    }

    /// Returns the time of beat `n`: `phase_s + n * tempo.period_s`.
    ///
    /// `n` is not checked against `beat_count`.
    #[must_use]
    pub fn beat_time(&self, n: u32) -> f32 {
        let offset_from_first = n as f32 * self.tempo.period_s;
        self.phase_s + offset_from_first
    }

    /// Returns the times of beats `0..beat_count`, in ascending beat order.
    #[must_use]
    pub fn all_beat_times(&self) -> Vec<f32> {
        let beat_indices = 0..self.beat_count;
        let mut times = Vec::new();
        times.extend(beat_indices.map(|index| self.beat_time(index)));
        times
    }
}
/// A time-signature hypothesis such as 3/4, with a confidence score.
#[derive(Debug, Clone, PartialEq)]
pub struct TimeSignature {
    /// Upper number of the signature.
    pub numerator: u8,
    /// Lower number of the signature.
    pub denominator: u8,
    /// Confidence of the hypothesis; [`TimeSignature::new`] limits it to `[0.0, 1.0]`.
    pub confidence: f32,
}

impl TimeSignature {
    /// Creates a signature, clamping `confidence` into `[0.0, 1.0]`.
    ///
    /// `numerator` and `denominator` are stored without validation.
    #[must_use]
    pub fn new(numerator: u8, denominator: u8, confidence: f32) -> Self {
        let confidence = confidence.clamp(0.0, 1.0);
        Self {
            numerator,
            denominator,
            confidence,
        }
    }

    /// Returns the number of beats in one bar, which is the numerator.
    #[must_use]
    pub fn beats_per_bar(&self) -> u8 {
        let Self { numerator, .. } = self;
        *numerator
    }
}
/// Estimates a tempo from the gaps between consecutive onsets.
///
/// Inter-onset intervals (IOIs) are taken between neighbouring entries of
/// `onsets` in slice order. Only intervals inside the beat-period window implied
/// by `[min_bpm, max_bpm]` (that is, `60 / max_bpm ..= 60 / min_bpm` seconds) are
/// kept. The tempo is `60 / mean(kept IOIs)`, and the confidence is the fraction
/// of kept intervals that lie within 10 % of that mean.
///
/// A `min_bpm` of `0.0` means "no upper limit on the interval length".
///
/// Returns `None` when:
/// - fewer than two onsets are supplied;
/// - a bound is NaN, `max_bpm` is not positive, `min_bpm` is negative, or
///   `min_bpm > max_bpm`;
/// - no interval falls inside the window;
/// - the mean interval is not a positive, finite number (for example when every
///   kept interval is zero).
#[must_use]
pub fn estimate_tempo(onsets: &[Onset], min_bpm: f32, max_bpm: f32) -> Option<TempoEstimate> {
    if onsets.len() < 2 {
        return None;
    }

    // Without this check a non-positive `max_bpm` yields a window that admits
    // zero or negative intervals, and the "tempo" comes out infinite or negative.
    // Every comparison is false for NaN, so NaN bounds are rejected here as well.
    let bounds_valid = max_bpm > 0.0 && min_bpm >= 0.0 && min_bpm <= max_bpm;
    if !bounds_valid {
        return None;
    }

    let min_period = 60.0 / max_bpm;
    let max_period = 60.0 / min_bpm;
    let period_window = min_period..=max_period;

    let iois: Vec<f32> = onsets
        .windows(2)
        .map(|pair| pair[1].time_s - pair[0].time_s)
        .filter(|ioi| period_window.contains(ioi))
        .collect();
    if iois.is_empty() {
        return None;
    }

    let mean_ioi = iois.iter().sum::<f32>() / iois.len() as f32;
    // A zero mean (only possible when the window reaches down to 0 s) or a
    // non-finite mean cannot be converted into a meaningful BPM value.
    if !(mean_ioi.is_finite() && mean_ioi > 0.0) {
        return None;
    }

    let bpm = 60.0 / mean_ioi;
    let tolerance = mean_ioi * 0.10;
    let consistent = iois
        .iter()
        .filter(|&&ioi| (ioi - mean_ioi).abs() <= tolerance)
        .count();
    let confidence = consistent as f32 / iois.len() as f32;

    Some(TempoEstimate::from_bpm(bpm, confidence))
}
/// Fits a beat grid with the given tempo to a set of onsets.
///
/// The grid phase is searched over 64 evenly spaced offsets in `[0, period)`.
/// For each offset, the score is the number of onsets lying within a tenth of a
/// beat of the nearest grid beat. The first offset with the highest score wins.
/// The beat count is `ceil(duration_s / period)`.
///
/// Returns `None` when `onsets` is empty or when `tempo.period_s` is not a
/// positive, finite number.
#[must_use]
pub fn build_beat_grid(
    onsets: &[Onset],
    tempo: TempoEstimate,
    duration_s: f32,
) -> Option<BeatGrid> {
    // Number of candidate phases tried across one beat period.
    const PHASE_STEPS: u32 = 64;
    // Maximum distance from a grid beat, in beats, for an onset to count as aligned.
    const ALIGN_TOLERANCE: f32 = 0.1;

    let period = tempo.period_s;
    // `period <= 0.0` alone lets NaN and infinity through, which would produce a
    // NaN phase further down; reject every non-finite period explicitly.
    if onsets.is_empty() || !period.is_finite() || period <= 0.0 {
        return None;
    }

    // Counts the onsets that land close to a beat of a grid starting at `phase`.
    let aligned_onsets = |phase: f32| {
        onsets
            .iter()
            .filter(|onset| {
                let beats_from_phase = (onset.time_s - phase) / period;
                let distance = beats_from_phase - beats_from_phase.round();
                distance.abs() < ALIGN_TOLERANCE
            })
            .count()
    };

    let step_size = period / PHASE_STEPS as f32;
    // Seed the search with step 0 (phase 0.0) instead of a sentinel score.
    let mut best_phase = 0.0_f32;
    let mut best_score = aligned_onsets(best_phase);
    for step in 1..PHASE_STEPS {
        let phase = step as f32 * step_size;
        let score = aligned_onsets(phase);
        // Strict comparison keeps the earliest phase among equally good ones.
        if score > best_score {
            best_score = score;
            best_phase = phase;
        }
    }

    // The float-to-int `as` cast saturates: negative or NaN quotients become 0.
    let beat_count = (duration_s / period).ceil() as u32;
    Some(BeatGrid::new(tempo, best_phase, beat_count))
}
/// Scores a fixed set of time-signature candidates against a beat grid.
///
/// The candidates are 4/4, 3/4, 6/8, 2/4 and 5/4. Each one's confidence is the
/// fraction of `onsets` that lie within 0.15 beats of a grid beat, measured
/// symmetrically so that onsets slightly before or after a beat both count.
/// The result is sorted by descending confidence; the sort is stable, so ties
/// keep the candidate order listed above.
///
/// Every confidence is `0.0` when `onsets` is empty or the bar length
/// (`period * numerator`) is not positive.
///
/// NOTE(review): the score depends only on the sub-beat offset of each onset,
/// which is the same for every numerator. All candidates therefore currently
/// receive the same confidence and the ranking falls back to the candidate
/// order. Telling metres apart would need bar-level information, for example
/// onset strength accents. Confirm the intended scoring.
#[must_use]
pub fn detect_time_signature(onsets: &[Onset], grid: &BeatGrid) -> Vec<TimeSignature> {
    const CANDIDATES: [(u8, u8); 5] = [(4, 4), (3, 4), (6, 8), (2, 4), (5, 4)];
    // Maximum distance from a beat, in beats, for an onset to count as aligned.
    const ALIGN_TOLERANCE: f32 = 0.15;

    let period = grid.tempo.period_s;
    let mut results: Vec<TimeSignature> = CANDIDATES
        .iter()
        .map(|&(num, den)| {
            let beats_per_bar = f32::from(num);
            let bar_len = period * beats_per_bar;
            let score = if bar_len > 0.0 && !onsets.is_empty() {
                let aligned = onsets
                    .iter()
                    .filter(|onset| {
                        let beat_pos = (onset.time_s - grid.phase_s) / period;
                        let beat_in_bar = beat_pos % beats_per_bar;
                        // Distance to the nearest beat. A plain `fract()` is
                        // one-sided: an onset a hair before a beat (or a value
                        // such as 2.9999998 from float round-off) has a
                        // fractional part near 1 and would be counted as
                        // unaligned.
                        let distance = beat_in_bar - beat_in_bar.round();
                        distance.abs() < ALIGN_TOLERANCE
                    })
                    .count();
                aligned as f32 / onsets.len() as f32
            } else {
                0.0
            };
            TimeSignature::new(num, den, score)
        })
        .collect();

    // Highest confidence first; incomparable values are treated as equal.
    results.sort_by(|a, b| {
        b.confidence
            .partial_cmp(&a.confidence)
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    results
}
/// Computes an unnormalised autocorrelation of `strengths` for every lag.
///
/// For each lag `k` in `1..strengths.len()`, the score is the sum of
/// `strengths[i] * strengths[i + k]` over all valid `i`. Each entry of the
/// result is `(k * hop_s, score)`, ordered by increasing lag.
///
/// Returns an empty vector when fewer than two values are supplied.
#[must_use]
pub fn periodicity_scores(strengths: &[f32], hop_s: f32) -> Vec<(f32, f32)> {
    let len = strengths.len();
    let mut scores = Vec::with_capacity(len.saturating_sub(1));
    // For `len < 2` the range below is empty and the vector stays empty.
    for lag in 1..len {
        let leading = &strengths[..len - lag];
        let delayed = &strengths[lag..];
        let correlation = leading
            .iter()
            .zip(delayed)
            .map(|(x, y)| x * y)
            .sum::<f32>();
        scores.push((lag as f32 * hop_s, correlation));
    }
    scores
}
/// Generates a raised-cosine strength envelope that repeats once per beat.
///
/// The beat length in frames is `sample_rate * 60 / bpm`. Frame `i` gets the
/// value `0.5 * (1 + cos(2π · i / beat_length))`, floored at `0.0`, so frame 0
/// is a peak of `1.0`.
///
/// Returns `n_frames` values.
#[must_use]
pub fn synthetic_onset_strengths(bpm: f32, sample_rate: f32, n_frames: usize) -> Vec<f32> {
    let period_frames = sample_rate * 60.0 / bpm;
    let mut envelope = Vec::with_capacity(n_frames);
    for frame in 0..n_frames {
        let angle = 2.0 * PI * frame as f32 / period_frames;
        let raised_cosine = 0.5 * (1.0 + angle.cos());
        envelope.push(raised_cosine.max(0.0));
    }
    envelope
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds full-strength onsets at the given times.
    fn make_onsets(times: &[f32]) -> Vec<Onset> {
        times.iter().map(|&t| Onset::new(t, 1.0)).collect()
    }

    #[test]
    fn test_onset_creation() {
        let o = Onset::new(1.5, 0.8);
        assert!((o.time_s - 1.5).abs() < 1e-6);
        assert!((o.strength - 0.8).abs() < 1e-6);
    }

    #[test]
    fn test_onset_strength_clamped() {
        let o = Onset::new(0.0, 2.5);
        assert!((o.strength - 1.0).abs() < 1e-6);
        let o2 = Onset::new(0.0, -0.5);
        assert!((o2.strength).abs() < 1e-6);
    }

    #[test]
    fn test_tempo_estimate_period() {
        let t = TempoEstimate::from_bpm(120.0, 0.9);
        assert!((t.period_s - 0.5).abs() < 1e-5);
    }

    #[test]
    fn test_tempo_estimate_plausible() {
        assert!(TempoEstimate::from_bpm(120.0, 1.0).is_plausible());
        assert!(!TempoEstimate::from_bpm(10.0, 1.0).is_plausible());
        assert!(!TempoEstimate::from_bpm(400.0, 1.0).is_plausible());
    }

    #[test]
    fn test_estimate_tempo_empty() {
        assert!(estimate_tempo(&[], 40.0, 200.0).is_none());
    }

    #[test]
    fn test_estimate_tempo_single_onset() {
        let onsets = vec![Onset::new(0.0, 1.0)];
        assert!(estimate_tempo(&onsets, 40.0, 200.0).is_none());
    }

    #[test]
    fn test_estimate_tempo_regular_120bpm() {
        // Onsets every 0.5 s correspond to 120 BPM.
        let times: Vec<f32> = (0..10).map(|i| i as f32 * 0.5).collect();
        let onsets = make_onsets(&times);
        let est = estimate_tempo(&onsets, 40.0, 200.0).expect("tempo estimation should succeed");
        assert!((est.bpm - 120.0).abs() < 1.0, "bpm = {}", est.bpm);
        assert!(est.confidence > 0.8);
    }

    #[test]
    fn test_estimate_tempo_result_plausible() {
        let times: Vec<f32> = (0..8).map(|i| i as f32 * 0.6).collect();
        let onsets = make_onsets(&times);
        let est = estimate_tempo(&onsets, 40.0, 200.0).expect("tempo estimation should succeed");
        assert!(est.is_plausible());
    }

    #[test]
    fn test_build_beat_grid_none_on_empty() {
        let tempo = TempoEstimate::from_bpm(120.0, 1.0);
        assert!(build_beat_grid(&[], tempo, 10.0).is_none());
    }

    #[test]
    fn test_build_beat_grid_beat_times_length() {
        let times: Vec<f32> = (0..8).map(|i| i as f32 * 0.5).collect();
        let onsets = make_onsets(&times);
        let tempo = TempoEstimate::from_bpm(120.0, 1.0);
        let grid = build_beat_grid(&onsets, tempo, 4.0).expect("beat grid should succeed");
        assert_eq!(grid.all_beat_times().len() as u32, grid.beat_count);
    }

    #[test]
    fn test_build_beat_grid_beat_times_spaced_correctly() {
        let times: Vec<f32> = (0..8).map(|i| i as f32 * 0.5).collect();
        let onsets = make_onsets(&times);
        let tempo = TempoEstimate::from_bpm(120.0, 1.0);
        let grid = build_beat_grid(&onsets, tempo, 4.0).expect("beat grid should succeed");
        let beats = grid.all_beat_times();
        for w in beats.windows(2) {
            let diff = w[1] - w[0];
            assert!((diff - 0.5).abs() < 1e-5, "diff = {}", diff);
        }
    }

    #[test]
    fn test_detect_time_signature_returns_candidates() {
        let times: Vec<f32> = (0..16).map(|i| i as f32 * 0.5).collect();
        let onsets = make_onsets(&times);
        let tempo = TempoEstimate::from_bpm(120.0, 1.0);
        let grid = build_beat_grid(&onsets, tempo, 8.0).expect("beat grid should succeed");
        let sigs = detect_time_signature(&onsets, &grid);
        assert!(!sigs.is_empty());
        // Results must be ordered by non-increasing confidence.
        for w in sigs.windows(2) {
            assert!(w[0].confidence >= w[1].confidence);
        }
    }

    #[test]
    fn test_detect_time_signature_4_4_most_likely() {
        let times: Vec<f32> = (0..32).map(|i| i as f32 * 0.5).collect();
        let onsets = make_onsets(&times);
        let tempo = TempoEstimate::from_bpm(120.0, 1.0);
        let grid = build_beat_grid(&onsets, tempo, 16.0).expect("beat grid should succeed");
        let sigs = detect_time_signature(&onsets, &grid);
        assert!(!sigs.is_empty());
        assert!(sigs[0].numerator >= 2);
    }

    #[test]
    fn test_time_signature_beats_per_bar() {
        let sig = TimeSignature::new(3, 4, 0.9);
        assert_eq!(sig.beats_per_bar(), 3);
    }

    #[test]
    fn test_periodicity_scores_empty() {
        assert!(periodicity_scores(&[], 0.01).is_empty());
    }

    #[test]
    fn test_periodicity_scores_length() {
        let data: Vec<f32> = (0..20).map(|i| i as f32).collect();
        let scores = periodicity_scores(&data, 0.01);
        assert_eq!(scores.len(), data.len() - 1);
    }

    #[test]
    fn test_synthetic_onset_strengths_length() {
        let s = synthetic_onset_strengths(120.0, 100.0, 50);
        assert_eq!(s.len(), 50);
    }

    #[test]
    fn test_synthetic_onset_strengths_range() {
        let s = synthetic_onset_strengths(120.0, 100.0, 200);
        for &v in &s {
            assert!(v >= 0.0 && v <= 1.0, "value out of range: {v}");
        }
    }
}