use std::num::NonZeroUsize;
use non_empty_slice::{NonEmptySlice, NonEmptyVec};
use crate::{
AudioOnsetDetection, AudioSampleError, AudioSampleResult, AudioSamples, ParameterError,
operations::{onset::OnsetDetectionConfig, traits::AudioBeatTracking},
traits::StandardSample,
};
/// Result of a beat-tracking run: a tempo, the detected beat positions and
/// the configuration that accompanies them.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BeatTrackingData {
/// Tempo in beats per minute (printed as "Estimated Tempo" by `Display`).
pub tempo_bpm: f64,
/// Beat positions in seconds (printed under "Detected Beats (s)" by `Display`).
pub beat_times: Vec<f64>,
/// Configuration stored alongside the result.
pub config: BeatTrackingConfig,
}
impl BeatTrackingData {
/// Creates a `BeatTrackingData` from its three fields.
///
/// The arguments are stored unchanged; no validation is performed.
#[inline]
#[must_use]
pub const fn new(tempo_bpm: f64, beat_times: Vec<f64>, config: BeatTrackingConfig) -> Self {
Self {
tempo_bpm,
beat_times,
config,
}
}
}
/// Human-readable report: a tempo line, a header line, then one beat time per
/// line.
impl core::fmt::Display for BeatTrackingData {
    /// Writes the tempo with two decimals, the "Detected Beats (s):" header and
    /// every entry of `beat_times` with three decimals, each on its own line.
    ///
    /// Stops at, and returns, the first formatter error.
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        writeln!(f, "Estimated Tempo: {:.2} BPM", self.tempo_bpm)?;
        writeln!(f, "Detected Beats (s):")?;
        self.beat_times
            .iter()
            .try_for_each(|&time| writeln!(f, "{time:.3}"))
    }
}
/// Parameters for beat tracking at a fixed, caller-supplied tempo.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub struct BeatTrackingConfig {
/// Tempo in beats per minute; `detect_beats` forwards it to `track_beats_core`,
/// which derives the inter-beat interval as `60.0 / tempo_bpm` seconds.
pub tempo_bpm: f64,
/// Search tolerance around each expected beat, forwarded as
/// `tolerance_seconds`. `None` makes `track_beats_core` use 10% of the
/// inter-beat interval.
pub tolerance: Option<f64>,
/// Onset-detection settings used to build the onset-strength envelope; its
/// `hop_size` is also used to convert frames to seconds.
pub onset_config: OnsetDetectionConfig,
}
impl BeatTrackingConfig {
/// Creates a `BeatTrackingConfig` from its three fields.
///
/// The arguments are stored unchanged; no validation happens here
/// (`track_beats_core` rejects a non-positive tempo when it is called).
#[inline]
#[must_use]
pub const fn new(
tempo_bpm: f64,
tolerance: Option<f64>,
onset_config: OnsetDetectionConfig,
) -> Self {
Self {
tempo_bpm,
tolerance,
onset_config,
}
}
}
impl<T> AudioBeatTracking for AudioSamples<'_, T>
where
T: StandardSample,
{
fn detect_beats(&self, config: &BeatTrackingConfig) -> AudioSampleResult<BeatTrackingData> {
let sr = self.sample_rate_hz();
let onset = onset_strength_envelope(self, &config.onset_config, None)?;
let beats = track_beats_core(
&onset,
config.tempo_bpm,
sr,
config.onset_config.hop_size,
config.tolerance,
)?;
Ok(BeatTrackingData {
tempo_bpm: config.tempo_bpm,
beat_times: beats,
config: config.clone(),
})
}
}
/// Computes a smoothed, log-compressed onset-strength envelope for `audio`.
///
/// The onset detection function (ODF) of `audio` is averaged over a window
/// centred on each frame and then mapped through `ln(1 + c * x)`.
///
/// # Arguments
/// * `audio` - Signal whose ODF is computed.
/// * `config` - Onset-detection settings. `config.window_size` (default 3) is
///   used as the half-width, in frames, of the moving-average window.
/// * `log_compression` - Compression factor `c`; defaults to `0.5`.
///
/// # Returns
/// One envelope value per ODF frame.
///
/// # Errors
/// Propagates any error from `onset_detection_function`.
pub fn onset_strength_envelope<T>(
    audio: &AudioSamples<'_, T>,
    config: &OnsetDetectionConfig,
    log_compression: Option<f64>,
) -> AudioSampleResult<NonEmptyVec<f64>>
where
    T: StandardSample,
{
    let (_, detection) = audio.onset_detection_function(config)?;
    let detection = detection.to_vec();
    let frame_count = detection.len();
    let half_width = config.window_size.unwrap_or(crate::nzu!(3)).get();
    let gain = log_compression.unwrap_or(0.5);

    let envelope: Vec<f64> = (0..frame_count)
        .map(|frame| {
            // Window [frame - half_width, frame + half_width], clipped to the
            // valid frame range, so edge frames average over fewer samples.
            let lo = frame.saturating_sub(half_width);
            let hi = (frame + half_width + 1).min(frame_count);
            let total = detection[lo..hi]
                .iter()
                .fold(0.0_f64, |sum, value| sum + *value);
            let mean = total / (hi - lo) as f64;
            (gain * mean).ln_1p()
        })
        .collect();

    // SAFETY: `envelope` has exactly as many elements as the ODF, so this is
    // sound only if the ODF is non-empty.
    // NOTE(review): assumes `onset_detection_function` never yields an empty
    // ODF — confirm against its contract.
    let envelope = unsafe { NonEmptyVec::new_unchecked(envelope) };
    Ok(envelope)
}
/// Returns the index of the largest value in `slice`.
///
/// Ties resolve to the earliest index because only a strictly greater value
/// replaces the current best. `NaN` entries never compare greater and are
/// therefore skipped. Returns `0` when the slice is empty or when no element
/// exceeds negative infinity (all `NaN` / all `-inf`).
#[inline]
fn peak_index(slice: &[f64]) -> usize {
    // Safe iteration replaces the previous `get_unchecked` loop; the iterator
    // already avoids per-element bounds checks.
    slice
        .iter()
        .enumerate()
        .fold(
            (0usize, f64::NEG_INFINITY),
            |(best_i, best_v), (i, &v)| {
                if v > best_v {
                    (i, v)
                } else {
                    (best_i, best_v)
                }
            },
        )
        .0
}
/// Places beats on an onset-strength envelope for a fixed, known tempo.
///
/// The strongest frame of `onset` (the first one on ties) anchors the beat
/// grid. Expected beat positions are then stepped forwards and backwards from
/// the anchor in whole inter-beat intervals, and each expected position is
/// snapped to the strongest frame inside the half-open window
/// `[center - tol, center + tol)`, clipped to the envelope.
///
/// # Arguments
/// * `onset` - Onset-strength envelope, one value per analysis frame.
/// * `tempo_bpm` - Tempo in beats per minute; must be positive.
/// * `sample_rate` - Sample rate used to convert frames to seconds; must be
///   positive and finite.
/// * `hop_size` - Number of samples between consecutive envelope frames.
/// * `tolerance_seconds` - Half-width of the search window in seconds.
///   `None` uses 10% of the inter-beat interval. The window is never narrower
///   than one frame.
///
/// # Returns
/// Beat times in seconds (`frame * hop_size / sample_rate`). The vector is
/// **not sorted**: it holds the anchor first, then the forward beats, then
/// the backward beats in stepping order.
///
/// # Errors
/// Returns `AudioSampleError::Parameter` when `tempo_bpm` is NaN or not
/// positive, when `sample_rate` is not a positive finite number, or when the
/// inter-beat interval rounds to zero frames.
#[inline]
pub fn track_beats_core(
    onset: &NonEmptySlice<f64>,
    tempo_bpm: f64,
    sample_rate: f64,
    hop_size: NonZeroUsize,
    tolerance_seconds: Option<f64>,
) -> AudioSampleResult<Vec<f64>> {
    // `<=` is false for NaN, so NaN has to be rejected explicitly.
    if tempo_bpm.is_nan() || tempo_bpm <= 0.0 {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "tempo_bpm",
            tempo_bpm,
        )));
    }
    // A zero, negative, NaN or infinite sample rate would otherwise surface as
    // a misleading `tempo_bpm` error or produce a degenerate frame duration.
    if !sample_rate.is_finite() || sample_rate <= 0.0 {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "sample_rate",
            sample_rate,
        )));
    }

    let hop_time = hop_size.get() as f64 / sample_rate;
    let ibi_seconds = 60.0 / tempo_bpm;
    // Float-to-int `as` saturates, so `ibi_frames` may be as large as
    // `isize::MAX`; the loops below are written so that this cannot overflow.
    let ibi_frames = (ibi_seconds / hop_time).round() as isize;
    if ibi_frames <= 0 {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "tempo_bpm",
            "Inter-beat interval too small",
        )));
    }

    let frame_count = onset.len().get();
    let len = frame_count as isize;

    // The search window is clipped to `[0, len)` anyway, so a tolerance above
    // `len` frames selects the same window as `len` itself. Clamping keeps
    // `center + tol_frames` from overflowing for huge or infinite tolerances.
    let tol_frames = tolerance_seconds
        .map_or_else(
            || (ibi_frames as f64 * 0.1).round() as isize,
            |t| (t / hop_time).round() as isize,
        )
        .max(1)
        .min(len);

    // Anchor: the global maximum of the envelope (first one on ties).
    let start = peak_index(&onset[0..frame_count]) as isize;

    // Snaps an expected beat position to the strongest frame in its window.
    // `center` is always within `[0, len)` here, so the window is non-empty.
    let snap_to_peak = |center: isize| -> isize {
        let lo = (center - tol_frames).max(0) as usize;
        let hi = (center + tol_frames).min(len) as usize;
        (lo + peak_index(&onset[lo..hi])) as isize
    };

    let est_beats = (len / ibi_frames) as usize + 1;
    let mut times = Vec::with_capacity(est_beats);
    times.push(start as f64 * hop_time);

    // Forward pass. `len - center` is always positive, so this form of
    // `center + ibi_frames < len` cannot overflow.
    let mut center = start;
    while ibi_frames < len - center {
        center += ibi_frames;
        times.push(snap_to_peak(center) as f64 * hop_time);
    }

    // Backward pass, equivalent to `center - ibi_frames >= 0`.
    let mut center = start;
    while center >= ibi_frames {
        center -= ibi_frames;
        times.push(snap_to_peak(center) as f64 * hop_time);
    }

    Ok(times)
}
// Property tests for `track_beats_core`, driven by synthetic onset envelopes.
#[cfg(test)]
mod tests {
use super::*;
use non_empty_slice::NonEmptyVec;
use proptest::prelude::*;
// Builds an envelope of `len` zeros with a 1.0 impulse every 20 frames
// (indices 0, 20, 40, ...). Panics via `unwrap` if `NonEmptyVec::new`
// rejects the vector — presumably only when `len` is 0.
fn synthetic_onset(len: usize) -> NonEmptyVec<f64> {
let mut v = vec![0.0; len];
// `20.max(1)` evaluates to 20; the `max` only guards `step_by` against a zero step.
for i in (0..len).step_by(20.max(1)) {
v[i] = 1.0;
}
let v = NonEmptyVec::new(v).unwrap();
v
}
proptest! {
// Every returned beat time is a finite, non-negative number of seconds.
#[test]
fn beat_times_are_finite_and_non_negative(
len in 64usize..2048,
tempo in 40.0f64..240.0,
sr in 8_000.0f64..96_000.0,
hop in 1usize..2048,
) {
let onset = synthetic_onset(len);
let hop = NonZeroUsize::new(hop).unwrap();
let beats = track_beats_core(
&onset,
tempo,
sr,
hop,
None,
).unwrap();
for &t in &beats {
prop_assert!(t.is_finite());
prop_assert!(t >= 0.0);
}
}
// No beat time exceeds the duration covered by the envelope
// (`len * hop / sr` seconds, with a small epsilon for rounding).
#[test]
fn beat_times_within_signal_bounds(
len in 64usize..4096,
tempo in 40.0f64..240.0,
sr in 8_000.0f64..96_000.0,
hop in 1usize..1024,
) {
let onset = synthetic_onset(len);
let duration = (len as f64 * hop as f64) / sr;
let hop = NonZeroUsize::new(hop).unwrap();
let beats = track_beats_core(
&onset,
tempo,
sr,
hop,
None,
).unwrap();
for &t in &beats {
prop_assert!(t <= duration + 1e-6);
}
}
// With a single impulse at `len / 3`, the first returned beat maps back to
// that frame, i.e. the output starts with the global peak rather than the
// earliest beat in time.
#[test]
fn first_beat_is_global_peak(
len in 128usize..2048,
tempo in 60.0f64..180.0,
sr in 16_000.0f64..48_000.0,
hop in 1usize..1024,
) {
let hop = NonZeroUsize::new(hop).unwrap();
let onset = vec![0.0; len];
let mut onset = NonEmptyVec::new(onset).unwrap();
let peak_idx = len / 3;
onset[peak_idx] = 10.0;
let beats = track_beats_core(
&onset,
tempo,
sr,
hop,
None,
).unwrap();
// Convert the first beat time back to a frame index.
let first_frame = (beats[0] * sr / hop.get() as f64).round() as usize;
prop_assert_eq!(first_frame, peak_idx);
}
// Checks the output layout: anchor first, then all forward beats (>= anchor),
// then all backward beats (< anchor), never interleaved.
// NOTE(review): `synthetic_onset` places equal impulses starting at index 0 and
// `track_beats_core` keeps the first maximum, so the anchor is always frame 0
// and no backward beats occur here; the backward branches below are not
// exercised by this input.
#[test]
fn insertion_order_preserves_forward_then_backward_structure(
len in 256usize..4096,
tempo in 60.0f64..180.0,
sr in 16_000.0f64..48_000.0,
hop in 1usize..512,
) {
let onset = synthetic_onset(len);
let hop = NonZeroUsize::new(hop).unwrap();
let beats = track_beats_core(
&onset,
tempo,
sr,
hop,
None,
).unwrap();
if beats.len() >= 2 {
prop_assert!(beats[1] >= beats[0]);
}
if beats.len() >= 3 {
let first = beats[0];
let mut seen_forward = false;
let mut seen_backward = false;
// Once a backward beat has been seen, no forward beat may follow.
for i in 1..beats.len() {
if beats[i] >= first {
prop_assert!(!seen_backward, "Forward beat found after backward beat");
seen_forward = true;
} else {
seen_backward = true;
}
}
if seen_forward && seen_backward {
// Split at the first backward beat and verify both halves are homogeneous.
let first_backward_idx = beats.iter().position(|&t| t < first).unwrap();
for i in 1..first_backward_idx {
prop_assert!(beats[i] >= first, "Beat at index {} should be forward", i);
}
for i in first_backward_idx..beats.len() {
prop_assert!(beats[i] < first, "Beat at index {} should be backward", i);
}
}
}
}
}
}