use std::error::Error;
use std::fmt;
use std::fs::File;
use std::io::ErrorKind;
use std::path::Path;
use symphonia::core::audio::{AudioBufferRef, SampleBuffer};
use symphonia::core::codecs::DecoderOptions;
use symphonia::core::errors::Error as SymphoniaError;
use symphonia::core::formats::FormatOptions;
use symphonia::core::io::MediaSourceStream;
use symphonia::core::meta::MetadataOptions;
use symphonia::core::probe::Hint;
use symphonia::default::{get_codecs, get_probe};
/// Outcome of averaging a sequence of tap intervals (produced by `tap::analyze_intervals`).
#[derive(Debug, Clone, PartialEq)]
pub struct TapTempoAnalysis {
/// Mean of the supplied intervals in milliseconds, rounded to three decimals.
pub average_interval_ms: f64,
/// Tempo computed as `60_000 / mean interval`, rounded to three decimals.
pub bpm: f64,
/// `bpm` rounded to the nearest whole number.
pub rounded_bpm: u32,
}
/// Tempo report for decoded audio (produced by `file::analyze_samples` / `file::analyze_path`).
#[derive(Debug, Clone, PartialEq)]
pub struct AudioFileAnalysis {
/// Tempo of the best-scoring autocorrelation lag, rounded to three decimals.
pub bpm: f64,
/// `bpm` rounded to the nearest whole number.
pub rounded_bpm: u32,
/// `bpm` after `range::normalize` has doubled/halved it towards the requested BPM range.
pub normalized_bpm: f64,
/// Best score divided by the sum of the best and second-best scores, clamped to `0.0..=1.0`
/// and rounded to three decimals.
pub confidence: f64,
/// Duration of the audio in seconds, rounded to three decimals.
pub duration_seconds: f64,
/// Length of the audio that was analysed, in seconds; `file::analyze_samples` sets it to the
/// same value as `duration_seconds`.
pub analyzed_seconds: f64,
/// Sample rate (Hz) the analysis was run with.
pub sample_rate: u32,
}
/// Every failure the tap, conversion, range and file helpers in this file can report.
#[derive(Debug, Clone, PartialEq)]
pub enum TapTempoError {
/// Fewer than two tap intervals were supplied.
NotEnoughIntervals,
/// A tap interval failed the "must be positive" validation.
NonPositiveInterval,
/// A BPM argument failed the "greater than 0" validation.
NonPositiveBpm,
/// A millisecond argument failed the "greater than 0" validation.
NonPositiveMilliseconds,
/// `beats_per_bar` was zero.
InvalidBarLength,
/// The BPM range bounds were rejected (both must be positive and `min` below `max`).
InvalidRange,
/// The audio file could not be opened; the payload is the full display message.
Io(String),
/// Probing, decoder setup or packet reading/decoding failed; the payload is the full message.
Decode(String),
/// The container exposes no default track.
NoAudioTrack,
/// The track has no sample rate (or a sample rate of zero was passed in).
MissingSampleRate,
/// Not enough audio to run the analysis.
AudioTooShort,
/// The given path is not a regular file.
UnsupportedPath,
/// No autocorrelation lag produced a positive score.
DetectionFailed,
}
impl fmt::Display for TapTempoError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
TapTempoError::NotEnoughIntervals => {
write!(f, "at least two tap intervals are required")
}
TapTempoError::NonPositiveInterval => {
write!(f, "tap intervals must be positive numbers")
}
TapTempoError::NonPositiveBpm => write!(f, "BPM must be greater than 0"),
TapTempoError::NonPositiveMilliseconds => {
write!(f, "milliseconds must be greater than 0")
}
TapTempoError::InvalidBarLength => {
write!(f, "beats per bar must be greater than 0")
}
TapTempoError::InvalidRange => write!(f, "min BPM must be lower than max BPM"),
TapTempoError::Io(message) => write!(f, "{message}"),
TapTempoError::Decode(message) => write!(f, "{message}"),
TapTempoError::NoAudioTrack => write!(f, "no default audio track found in file"),
TapTempoError::MissingSampleRate => {
write!(f, "audio track is missing sample rate metadata")
}
TapTempoError::AudioTooShort => {
write!(f, "audio file is too short for reliable BPM analysis")
}
TapTempoError::UnsupportedPath => write!(f, "path must point to a regular audio file"),
TapTempoError::DetectionFailed => {
write!(f, "could not derive a reliable BPM candidate")
}
}
}
}
// Marker impl so `TapTempoError` can be used wherever `std::error::Error` is required
// (e.g. `Box<dyn Error>`); the empty body means the default methods are used and no source
// error is chained.
impl Error for TapTempoError {}
/// Rounds `value` to `precision` decimal places.
///
/// The value is scaled by `10^precision`, rounded with `f64::round` (halves away from zero) and
/// scaled back.
fn round_to(value: f64, precision: u32) -> f64 {
    let scale = f64::powi(10.0, precision as i32);
    let scaled = value * scale;
    scaled.round() / scale
}
/// Tempo estimation from manually tapped beat intervals.
pub mod tap {
    use super::{TapTempoAnalysis, TapTempoError, round_to};

    /// Averages tap intervals (milliseconds between taps) and converts the mean into a tempo.
    ///
    /// The mean interval and the BPM are rounded to three decimals; `rounded_bpm` is the BPM
    /// rounded to the nearest whole number.
    ///
    /// # Errors
    /// * [`TapTempoError::NotEnoughIntervals`] when fewer than two intervals are given.
    /// * [`TapTempoError::NonPositiveInterval`] when any interval is zero, negative, NaN or
    ///   infinite.
    pub fn analyze_intervals(intervals_ms: &[f64]) -> Result<TapTempoAnalysis, TapTempoError> {
        if intervals_ms.len() < 2 {
            return Err(TapTempoError::NotEnoughIntervals);
        }
        // `value <= 0.0` is false for NaN, so a plain sign check would let NaN through and
        // produce a NaN tempo; an infinite interval would silently yield 0 BPM.
        let is_usable = |value: &f64| value.is_finite() && *value > 0.0;
        if !intervals_ms.iter().all(is_usable) {
            return Err(TapTempoError::NonPositiveInterval);
        }

        let average_interval_ms = intervals_ms.iter().sum::<f64>() / intervals_ms.len() as f64;
        // The tempo is derived from the unrounded mean so rounding error is not compounded.
        let bpm = round_to(60_000.0 / average_interval_ms, 3);
        Ok(TapTempoAnalysis {
            average_interval_ms: round_to(average_interval_ms, 3),
            bpm,
            rounded_bpm: bpm.round() as u32,
        })
    }

    /// Convenience wrapper returning only the BPM from [`analyze_intervals`].
    ///
    /// # Errors
    /// Same as [`analyze_intervals`].
    pub fn bpm_from_intervals(intervals_ms: &[f64]) -> Result<f64, TapTempoError> {
        analyze_intervals(intervals_ms).map(|analysis| analysis.bpm)
    }
}
/// Conversions between tempo (BPM) and beat/bar durations in milliseconds.
pub mod convert {
    use super::{TapTempoError, round_to};

    /// Milliseconds in one minute.
    const MS_PER_MINUTE: f64 = 60_000.0;

    /// True for finite values strictly greater than zero (rejects NaN and infinities, which a
    /// bare `<= 0.0` comparison lets through).
    fn is_positive_finite(value: f64) -> bool {
        value.is_finite() && value > 0.0
    }

    /// Duration of one beat in milliseconds at `bpm`, rounded to three decimals.
    ///
    /// # Errors
    /// [`TapTempoError::NonPositiveBpm`] when `bpm` is zero, negative, NaN or infinite.
    pub fn bpm_to_ms_per_beat(bpm: f64) -> Result<f64, TapTempoError> {
        if !is_positive_finite(bpm) {
            return Err(TapTempoError::NonPositiveBpm);
        }
        Ok(round_to(MS_PER_MINUTE / bpm, 3))
    }

    /// Duration of one bar of `beats_per_bar` beats in milliseconds at `bpm`, rounded to three
    /// decimals.
    ///
    /// # Errors
    /// * [`TapTempoError::InvalidBarLength`] when `beats_per_bar` is zero (checked first).
    /// * [`TapTempoError::NonPositiveBpm`] when `bpm` is zero, negative, NaN or infinite.
    pub fn bpm_to_ms_per_bar(bpm: f64, beats_per_bar: u32) -> Result<f64, TapTempoError> {
        if beats_per_bar == 0 {
            return Err(TapTempoError::InvalidBarLength);
        }
        if !is_positive_finite(bpm) {
            return Err(TapTempoError::NonPositiveBpm);
        }
        // Round once, at the end: multiplying an already-rounded beat length would scale its
        // rounding error by `beats_per_bar` (e.g. 7 BPM x 4 beats gave 34285.716, not 34285.714).
        Ok(round_to(MS_PER_MINUTE * f64::from(beats_per_bar) / bpm, 3))
    }

    /// Tempo in BPM for a beat lasting `milliseconds`, rounded to three decimals.
    ///
    /// # Errors
    /// [`TapTempoError::NonPositiveMilliseconds`] when `milliseconds` is zero, negative, NaN or
    /// infinite.
    pub fn ms_per_beat_to_bpm(milliseconds: f64) -> Result<f64, TapTempoError> {
        if !is_positive_finite(milliseconds) {
            return Err(TapTempoError::NonPositiveMilliseconds);
        }
        Ok(round_to(MS_PER_MINUTE / milliseconds, 3))
    }

    /// Tempo in BPM for a bar of `beats_per_bar` beats lasting `milliseconds`, rounded to three
    /// decimals.
    ///
    /// # Errors
    /// * [`TapTempoError::NonPositiveMilliseconds`] when `milliseconds` is zero, negative, NaN
    ///   or infinite (checked first).
    /// * [`TapTempoError::InvalidBarLength`] when `beats_per_bar` is zero.
    pub fn ms_per_bar_to_bpm(milliseconds: f64, beats_per_bar: u32) -> Result<f64, TapTempoError> {
        if !is_positive_finite(milliseconds) {
            return Err(TapTempoError::NonPositiveMilliseconds);
        }
        if beats_per_bar == 0 {
            return Err(TapTempoError::InvalidBarLength);
        }
        ms_per_beat_to_bpm(milliseconds / f64::from(beats_per_bar))
    }
}
/// Helpers for folding a tempo towards a BPM range and for half/double-time conversion.
pub mod range {
    use super::{TapTempoError, round_to};

    /// Accepts only finite BPM values greater than zero.
    ///
    /// Infinity must be rejected because halving it never changes it, and NaN because it slips
    /// through every `<=` comparison.
    fn ensure_valid_bpm(bpm: f64) -> Result<(), TapTempoError> {
        if bpm.is_finite() && bpm > 0.0 {
            Ok(())
        } else {
            Err(TapTempoError::NonPositiveBpm)
        }
    }

    /// Accepts only ranges with positive bounds and `min < max`; written as a positive
    /// predicate so NaN bounds fail it.
    fn ensure_valid_range(min: f64, max: f64) -> Result<(), TapTempoError> {
        if min > 0.0 && max > 0.0 && min < max {
            Ok(())
        } else {
            Err(TapTempoError::InvalidRange)
        }
    }

    /// Doubles `bpm` until it is at least `min`, then halves it until it is at most `max`, and
    /// returns the result rounded to three decimals.
    ///
    /// When the range spans less than an octave (`max < 2 * min`) there may be no octave of
    /// `bpm` inside it; in that case the returned value can end up below `min`.
    ///
    /// # Errors
    /// * [`TapTempoError::NonPositiveBpm`] when `bpm` is zero, negative, NaN or infinite.
    /// * [`TapTempoError::InvalidRange`] when a bound is non-positive or NaN, or `min >= max`.
    pub fn normalize(bpm: f64, min: f64, max: f64) -> Result<f64, TapTempoError> {
        ensure_valid_bpm(bpm)?;
        ensure_valid_range(min, max)?;

        let mut normalized = bpm;
        while normalized < min {
            normalized *= 2.0;
        }
        while normalized > max {
            normalized /= 2.0;
        }
        Ok(round_to(normalized, 3))
    }

    /// Reports whether `bpm` lies inside the inclusive range `[min, max]`.
    ///
    /// # Errors
    /// Same validation as [`normalize`].
    pub fn is_within(bpm: f64, min: f64, max: f64) -> Result<bool, TapTempoError> {
        ensure_valid_bpm(bpm)?;
        ensure_valid_range(min, max)?;
        Ok((min..=max).contains(&bpm))
    }

    /// Half of `bpm`, rounded to three decimals.
    ///
    /// # Errors
    /// [`TapTempoError::NonPositiveBpm`] when `bpm` is zero, negative, NaN or infinite.
    pub fn half_time(bpm: f64) -> Result<f64, TapTempoError> {
        ensure_valid_bpm(bpm)?;
        Ok(round_to(bpm / 2.0, 3))
    }

    /// Twice `bpm`, rounded to three decimals.
    ///
    /// # Errors
    /// [`TapTempoError::NonPositiveBpm`] when `bpm` is zero, negative, NaN or infinite.
    pub fn double_time(bpm: f64) -> Result<f64, TapTempoError> {
        ensure_valid_bpm(bpm)?;
        Ok(round_to(bpm * 2.0, 3))
    }
}
pub mod file {
use super::{
AudioBufferRef, AudioFileAnalysis, DecoderOptions, ErrorKind, File, FormatOptions, Hint,
MediaSourceStream, MetadataOptions, Path, SampleBuffer, SymphoniaError, TapTempoError,
get_codecs, get_probe, range, round_to,
};
/// Number of samples in each analysis frame used by `build_onset_envelope`.
const FRAME_SIZE: usize = 1024;
/// Advance (in samples) between consecutive frames; half of `FRAME_SIZE`, so frames overlap by
/// 50%. Also defines the envelope frame rate (`sample_rate / HOP_SIZE`) in `analyze_samples`.
const HOP_SIZE: usize = 512;
/// `analyze_samples` rejects input shorter than this many seconds with `AudioTooShort`.
const MIN_ANALYSIS_SECONDS: f64 = 3.0;
/// `analyze_path` stops collecting decoded audio once it holds this many seconds of samples.
const MAX_ANALYSIS_SECONDS: usize = 180;
/// Decodes the audio file at `path` and estimates its tempo within `[min_bpm, max_bpm]`.
///
/// Only the container's default track is decoded. Every frame is down-mixed to mono and at most
/// `MAX_ANALYSIS_SECONDS` of audio is collected before it is handed to [`analyze_samples`].
///
/// In the returned report `analyzed_seconds` is the length of the audio that was actually
/// analysed, while `duration_seconds` is widened to the full stream length whenever the
/// container reports a frame count, so files longer than the analysis window are not reported
/// as being only as long as that window.
///
/// # Errors
/// * [`TapTempoError::UnsupportedPath`] when `path` is not a regular file.
/// * [`TapTempoError::InvalidRange`] when the BPM bounds are not positive with
///   `min_bpm < max_bpm` (checked before any decoding work is done).
/// * [`TapTempoError::Io`] when the file cannot be opened.
/// * [`TapTempoError::Decode`] when probing, decoder setup, packet reading or decoding fails.
/// * [`TapTempoError::NoAudioTrack`] / [`TapTempoError::MissingSampleRate`] when the container
///   lacks a default track or the track lacks a sample rate.
/// * Any error returned by [`analyze_samples`].
pub fn analyze_path<P: AsRef<Path>>(
    path: P,
    min_bpm: f64,
    max_bpm: f64,
) -> Result<AudioFileAnalysis, TapTempoError> {
    let path = path.as_ref();
    if !path.is_file() {
        return Err(TapTempoError::UnsupportedPath);
    }

    // Fail fast: without this an invalid range is only noticed after up to three minutes of
    // audio have been decoded. Written as a positive predicate so NaN bounds fail it too.
    let range_is_valid = min_bpm > 0.0 && max_bpm > 0.0 && min_bpm < max_bpm;
    if !range_is_valid {
        return Err(TapTempoError::InvalidRange);
    }

    let source = File::open(path).map_err(|error| {
        TapTempoError::Io(format!(
            "failed to open audio file {}: {error}",
            path.display()
        ))
    })?;
    let media_source = MediaSourceStream::new(Box::new(source), Default::default());

    // The file extension is only a hint; the probe still inspects the stream contents.
    let mut hint = Hint::new();
    if let Some(extension) = path.extension().and_then(|value| value.to_str()) {
        hint.with_extension(extension);
    }

    let probed = get_probe()
        .format(
            &hint,
            media_source,
            &FormatOptions::default(),
            &MetadataOptions::default(),
        )
        .map_err(|error| {
            TapTempoError::Decode(format!(
                "failed to probe audio file {}: {error}",
                path.display()
            ))
        })?;
    let mut format = probed.format;

    // Clone the parameters so the borrow of `format` ends before packets are read from it.
    let (track_id, codec_params) = {
        let track = format.default_track().ok_or(TapTempoError::NoAudioTrack)?;
        (track.id, track.codec_params.clone())
    };
    let sample_rate = codec_params
        .sample_rate
        .ok_or(TapTempoError::MissingSampleRate)?;
    // Saturate instead of overflowing on targets where `usize` is 32 bits wide.
    let max_samples = (sample_rate as usize).saturating_mul(MAX_ANALYSIS_SECONDS);

    let mut decoder = get_codecs()
        .make(&codec_params, &DecoderOptions::default())
        .map_err(|error| {
            TapTempoError::Decode(format!(
                "failed to initialize audio decoder for {}: {error}",
                path.display()
            ))
        })?;

    let mut samples = Vec::new();
    loop {
        let packet = match format.next_packet() {
            Ok(packet) => packet,
            // The format reader signals end of stream as an unexpected-EOF I/O error.
            Err(SymphoniaError::IoError(error)) if error.kind() == ErrorKind::UnexpectedEof => {
                break;
            }
            Err(error) => {
                return Err(TapTempoError::Decode(format!(
                    "failed to read packet from {}: {error}",
                    path.display()
                )));
            }
        };
        if packet.track_id() != track_id {
            continue;
        }
        match decoder.decode(&packet) {
            Ok(buffer) => append_mono_samples(buffer, &mut samples),
            // A corrupt packet is skipped rather than aborting the whole analysis.
            Err(SymphoniaError::DecodeError(_)) => continue,
            Err(SymphoniaError::IoError(error)) if error.kind() == ErrorKind::UnexpectedEof => {
                break;
            }
            Err(error) => {
                return Err(TapTempoError::Decode(format!(
                    "failed to decode audio data from {}: {error}",
                    path.display()
                )));
            }
        }
        if samples.len() >= max_samples {
            samples.truncate(max_samples);
            break;
        }
    }

    let mut analysis = analyze_samples(&samples, sample_rate, min_bpm, max_bpm)?;

    // `analyze_samples` only knows the (possibly truncated) buffer it was given. When the
    // container reports the stream length, use it for the file duration; the decoded length
    // stays as the lower bound in case the metadata under-reports.
    if let Some(total_frames) = codec_params.n_frames {
        let stream_seconds = match codec_params.time_base {
            Some(time_base) => {
                let length = time_base.calc_time(total_frames);
                length.seconds as f64 + length.frac
            }
            None => total_frames as f64 / f64::from(sample_rate),
        };
        // Same three-decimal rounding as the other fields of the report.
        let stream_seconds = (stream_seconds * 1_000.0).round() / 1_000.0;
        if stream_seconds > analysis.duration_seconds {
            analysis.duration_seconds = stream_seconds;
        }
    }
    Ok(analysis)
}
/// Estimates the tempo of mono `samples` recorded at `sample_rate` Hz.
///
/// An onset envelope is built from the samples and autocorrelated at every lag whose tempo
/// lies (approximately) between `max_bpm` and `min_bpm`; the best-scoring lag gives `bpm`.
/// `confidence` is the best score divided by the sum of the best and second-best scores.
/// `duration_seconds` and `analyzed_seconds` are both the length of `samples`.
///
/// # Errors
/// * [`TapTempoError::MissingSampleRate`] when `sample_rate` is zero.
/// * [`TapTempoError::InvalidRange`] when a bound is non-positive or NaN, or
///   `min_bpm >= max_bpm`.
/// * [`TapTempoError::AudioTooShort`] when there are fewer than `MIN_ANALYSIS_SECONDS` of
///   audio, fewer than four frames, or too few envelope frames for the slowest tempo.
/// * [`TapTempoError::DetectionFailed`] when no lag yields a positive score.
pub fn analyze_samples(
    samples: &[f32],
    sample_rate: u32,
    min_bpm: f64,
    max_bpm: f64,
) -> Result<AudioFileAnalysis, TapTempoError> {
    if sample_rate == 0 {
        return Err(TapTempoError::MissingSampleRate);
    }
    // Written as a positive predicate: every comparison with NaN is false, so NaN bounds are
    // rejected here instead of surfacing later as a misleading `AudioTooShort` or
    // `DetectionFailed`.
    let range_is_valid = min_bpm > 0.0 && max_bpm > 0.0 && min_bpm < max_bpm;
    if !range_is_valid {
        return Err(TapTempoError::InvalidRange);
    }

    let min_samples = (sample_rate as f64 * MIN_ANALYSIS_SECONDS) as usize;
    if samples.len() < min_samples || samples.len() < FRAME_SIZE * 4 {
        return Err(TapTempoError::AudioTooShort);
    }
    let duration_seconds = round_to(samples.len() as f64 / sample_rate as f64, 3);

    let envelope = build_onset_envelope(samples);
    if envelope.len() < 16 {
        return Err(TapTempoError::AudioTooShort);
    }

    // Envelope frames per second; a lag of `n` frames corresponds to `60 * frame_rate / n` BPM.
    let frame_rate = sample_rate as f64 / HOP_SIZE as f64;
    let min_lag = ((60.0 * frame_rate) / max_bpm).floor().max(1.0) as usize;
    let max_lag = ((60.0 * frame_rate) / min_bpm).ceil() as usize;
    if max_lag >= envelope.len() {
        return Err(TapTempoError::AudioTooShort);
    }

    // Track the winning lag plus the runner-up score, which feeds the confidence value.
    let mut best_lag = None;
    let mut best_score = 0.0;
    let mut second_score = 0.0;
    for lag in min_lag..=max_lag {
        let score = autocorrelation_score(&envelope, lag);
        if score > best_score {
            second_score = best_score;
            best_score = score;
            best_lag = Some(lag);
        } else if score > second_score {
            second_score = score;
        }
    }
    let best_lag = best_lag.ok_or(TapTempoError::DetectionFailed)?;
    if best_score <= 0.0 {
        return Err(TapTempoError::DetectionFailed);
    }

    let bpm = round_to((60.0 * frame_rate) / best_lag as f64, 3);
    let normalized_bpm = range::normalize(bpm, min_bpm, max_bpm)?;

    let denominator = if second_score > 0.0 {
        best_score + second_score
    } else {
        best_score
    };
    let confidence = if denominator > 0.0 {
        round_to((best_score / denominator).clamp(0.0, 1.0), 3)
    } else {
        0.0
    };

    Ok(AudioFileAnalysis {
        bpm,
        rounded_bpm: bpm.round() as u32,
        normalized_bpm,
        confidence,
        duration_seconds,
        analyzed_seconds: duration_seconds,
        sample_rate,
    })
}
/// Turns raw samples into an onset-strength envelope with one value per analysis frame.
///
/// Steps:
/// 1. mean absolute amplitude of every `FRAME_SIZE`-sample frame, advancing by `HOP_SIZE`;
/// 2. positive frame-to-frame increase of that energy (the first frame gets `0.0`);
/// 3. subtraction of the mean increase, with negative results clamped to zero.
///
/// Returns an empty vector when `samples` holds less than one full frame.
fn build_onset_envelope(samples: &[f32]) -> Vec<f64> {
    let frame_energies: Vec<f64> = samples
        .windows(FRAME_SIZE)
        .step_by(HOP_SIZE)
        .map(|frame| {
            let total: f64 = frame.iter().map(|sample| f64::from(sample.abs())).sum();
            total / FRAME_SIZE as f64
        })
        .collect();
    if frame_energies.is_empty() {
        return Vec::new();
    }

    // Only rises in energy count as onsets; the first frame has no predecessor.
    let mut rises: Vec<f64> = Vec::with_capacity(frame_energies.len());
    rises.push(0.0);
    rises.extend(
        frame_energies
            .windows(2)
            .map(|pair| (pair[1] - pair[0]).max(0.0)),
    );

    let mean_rise = rises.iter().sum::<f64>() / rises.len() as f64;
    rises
        .into_iter()
        .map(|rise| (rise - mean_rise).max(0.0))
        .collect()
}
/// Mean product of the envelope with a copy of itself delayed by `lag` frames.
///
/// Returns `0.0` when `lag` leaves no overlapping frames (`lag >= envelope.len()`).
fn autocorrelation_score(envelope: &[f64], lag: usize) -> f64 {
    if lag >= envelope.len() {
        return 0.0;
    }
    let pair_count = envelope.len() - lag;
    // Pair every frame from `lag` onwards with the frame `lag` positions before it.
    let total = envelope[lag..]
        .iter()
        .zip(envelope)
        .fold(0.0, |sum, (later, earlier)| sum + later * earlier);
    total / pair_count as f64
}
/// Down-mixes one decoded buffer to mono and appends the result to `samples`.
///
/// The buffer is copied out as interleaved `f32` and every frame is reduced to the arithmetic
/// mean of its channel values. A buffer whose spec reports zero channels contributes nothing
/// (previously it panicked in `chunks(0)`).
fn append_mono_samples(decoded: AudioBufferRef<'_>, samples: &mut Vec<f32>) {
    let spec = *decoded.spec();
    let channels = spec.channels.count();
    if channels == 0 {
        // Nothing to average, and a chunk size of zero would panic below.
        return;
    }
    let frames = decoded.frames() as u64;
    let mut sample_buffer = SampleBuffer::<f32>::new(frames, spec);
    sample_buffer.copy_interleaved_ref(decoded);

    // Interleaved layout: each run of `channels` consecutive values is one frame.
    samples.extend(
        sample_buffer
            .samples()
            .chunks_exact(channels)
            .map(|frame| frame.iter().copied().sum::<f32>() / channels as f32),
    );
}
}
#[cfg(test)]
mod tests {
    use super::{convert, file, range, tap};

    /// Renders a mono click track: silence with a 512-sample pulse at every beat of `bpm`.
    ///
    /// Each pulse starts at amplitude 0.9 and decays linearly; pulses running past the end of
    /// the buffer are cut off.
    fn synthetic_click_track(bpm: f64, sample_rate: u32, duration_seconds: f64) -> Vec<f32> {
        const PULSE_LEN: usize = 512;

        let total_samples = (sample_rate as f64 * duration_seconds) as usize;
        let beat_spacing = (sample_rate as f64 * 60.0 / bpm) as usize;
        let mut track = vec![0.0_f32; total_samples];

        let mut pulse_start = 0usize;
        while pulse_start < total_samples {
            let pulse_end = (pulse_start + PULSE_LEN).min(total_samples);
            for (offset, sample) in track[pulse_start..pulse_end].iter_mut().enumerate() {
                let decay = 1.0 - offset as f32 / PULSE_LEN as f32;
                *sample = decay * 0.9;
            }
            pulse_start += beat_spacing;
        }
        track
    }

    #[test]
    fn converts_bpm_to_milliseconds_per_beat() {
        let ms_per_beat = convert::bpm_to_ms_per_beat(120.0).unwrap();
        assert_eq!(ms_per_beat, 500.0);
    }

    #[test]
    fn converts_bar_milliseconds_to_bpm() {
        let bpm = convert::ms_per_bar_to_bpm(1875.0, 4).unwrap();
        assert_eq!(bpm, 128.0);
    }

    #[test]
    fn normalizes_into_practical_range() {
        let normalized = range::normalize(72.0, 90.0, 180.0).unwrap();
        assert_eq!(normalized, 144.0);
    }

    #[test]
    fn returns_half_time() {
        let halved = range::half_time(174.0).unwrap();
        assert_eq!(halved, 87.0);
    }

    #[test]
    fn returns_double_time() {
        let doubled = range::double_time(87.5).unwrap();
        assert_eq!(doubled, 175.0);
    }

    #[test]
    fn analyzes_stable_tap_intervals() {
        let steady_taps = [500.0, 500.0, 500.0, 500.0];
        let report = tap::analyze_intervals(&steady_taps).unwrap();
        assert_eq!(report.average_interval_ms, 500.0);
        assert_eq!(report.bpm, 120.0);
        assert_eq!(report.rounded_bpm, 120);
    }

    #[test]
    fn analyzes_slightly_variable_tap_intervals() {
        let uneven_taps = [500.0, 480.0, 495.0, 505.0];
        let report = tap::analyze_intervals(&uneven_taps).unwrap();
        assert_eq!(report.average_interval_ms, 495.0);
        assert_eq!(report.bpm, 121.212);
        assert_eq!(report.rounded_bpm, 121);
    }

    #[test]
    fn rejects_short_tap_sequences() {
        let failure = tap::analyze_intervals(&[500.0]).unwrap_err();
        assert_eq!(
            failure.to_string(),
            "at least two tap intervals are required"
        );
    }

    #[test]
    fn analyzes_synthetic_audio_samples() {
        let sample_rate = 44_100;
        let clicks = synthetic_click_track(120.0, sample_rate, 8.0);
        let report = file::analyze_samples(&clicks, sample_rate, 70.0, 180.0).unwrap();
        assert!((report.bpm - 120.0).abs() < 1.0);
        assert_eq!(report.rounded_bpm, 120);
        assert!((report.normalized_bpm - 120.0).abs() < 1.0);
        assert!(report.confidence > 0.5);
    }
}