pub mod complex;
pub mod filters;
pub mod flux;
pub mod kernels;
use std::{num::NonZeroUsize, str::FromStr};
use ndarray::Array2;
use non_empty_slice::{NonEmptyVec, non_empty_vec};
use spectrograms::{CqtParams, SpectrogramParams, StftParams};
use crate::{
AudioSampleError, AudioSampleResult, AudioSamples, AudioTransforms, AudioTypeConversion,
ParameterError, StandardSample,
operations::{
onset::{
complex::{combine_complex_odf, magnitude_difference, phase_deviation},
filters::{log_compress_inplace, median_filter, rectify_inplace},
flux::{complex_flux, energy_flux, magnitude_flux, rectified_complex_flux},
kernels::{apply_adaptive_threshold, energy_odf},
},
peak_picking::pick_peaks,
traits::AudioOnsetDetection,
types::PeakPickingConfig,
},
};
/// Configuration for the CQT-based onset detector.
///
/// Read by `detect_onsets` and `onset_detection_function` in this module.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct OnsetDetectionConfig {
/// Constant-Q transform parameters. Its `window` is forwarded to the STFT
/// builder and its `f_min` is used to derive the window size when
/// `window_size` is `None`.
pub cqt_params: CqtParams,
/// Hop between analysis frames, in samples. Also used to convert frame
/// indices to seconds (`frame_index * hop_size / sample_rate`).
pub hop_size: NonZeroUsize,
/// Explicit analysis window (`n_fft`) size. When `None`, the size is derived
/// from the sample rate and `cqt_params.f_min` (see `effective_window_size`).
pub window_size: Option<NonZeroUsize>,
/// Detection threshold.
/// NOTE(review): not read by the detection methods in this file — confirm
/// where (or whether) it is consumed.
pub threshold: f64,
/// Minimum interval between onsets, in seconds (per the field name).
/// NOTE(review): not read by the detection methods in this file — confirm.
pub min_onset_interval_secs: f64,
/// Pre-emphasis amount.
/// NOTE(review): not read by the detection methods in this file — confirm.
pub pre_emphasis: f64,
/// When `true`, `onset_detection_function` median-filters the detection
/// function and passes both to `apply_adaptive_threshold`.
pub adaptive_threshold: bool,
/// Length passed to `median_filter` when `adaptive_threshold` is enabled.
pub median_filter_length: NonZeroUsize,
/// Multiplier passed to `apply_adaptive_threshold` alongside the median.
pub adaptive_threshold_multiplier: f64,
/// Peak-picking settings passed to `pick_peaks` by `detect_onsets`.
pub peak_picking: PeakPickingConfig,
}
impl OnsetDetectionConfig {
    /// Builds a configuration from explicit values.
    ///
    /// No validation is performed; every argument is stored as given.
    #[inline]
    #[must_use]
    pub const fn new(
        cqt_params: CqtParams,
        hop_size: NonZeroUsize,
        window_size: Option<NonZeroUsize>,
        threshold: f64,
        min_onset_interval_secs: f64,
        pre_emphasis: f64,
        adaptive_threshold: bool,
        median_filter_length: NonZeroUsize,
        adaptive_threshold_multiplier: f64,
        peak_picking: PeakPickingConfig,
    ) -> Self {
        Self {
            cqt_params,
            hop_size,
            window_size,
            threshold,
            min_onset_interval_secs,
            pre_emphasis,
            adaptive_threshold,
            median_filter_length,
            adaptive_threshold_multiplier,
            peak_picking,
        }
    }

    /// Returns the analysis window size to use at `sample_rate`.
    ///
    /// If `window_size` is set it is returned unchanged. Otherwise the size is
    /// four periods of `cqt_params.f_min`, i.e. `4 * sample_rate / f_min`
    /// truncated to an integer.
    ///
    /// The derived size is never zero: if the computation truncates to `0`
    /// (for example when `f_min` exceeds `4 * sample_rate`, or when either
    /// input is NaN or non-positive) the smallest non-zero size is returned.
    #[inline]
    #[must_use]
    pub fn effective_window_size(&self, sample_rate: f64) -> NonZeroUsize {
        self.window_size.unwrap_or_else(|| {
            let min_period = sample_rate / self.cqt_params.f_min;
            // Float-to-int `as` casts saturate and map NaN / negatives to 0,
            // so the result may legitimately be 0. Constructing a
            // `NonZeroUsize` from 0 without a check would be undefined
            // behaviour, hence the checked constructor with a fallback.
            NonZeroUsize::new((min_period * 4.0) as usize).unwrap_or(NonZeroUsize::MIN)
        })
    }

    /// Converts a frame index to a time in seconds:
    /// `frame_index * hop_size / sample_rate`.
    #[inline]
    #[must_use]
    pub fn frame_to_seconds(&self, frame_index: usize, sample_rate: f64) -> f64 {
        (frame_index as f64 * self.hop_size.get() as f64) / sample_rate
    }

    /// Preset for percussive material.
    ///
    /// Uses `CqtParams::percussive()`, a 256-sample hop and
    /// `PeakPickingConfig::drums()`.
    #[inline]
    #[must_use]
    pub fn percussive() -> Self {
        Self {
            cqt_params: CqtParams::percussive(),
            hop_size: crate::nzu!(256),
            window_size: None,
            threshold: 0.5,
            min_onset_interval_secs: 0.03,
            pre_emphasis: 0.3,
            adaptive_threshold: true,
            median_filter_length: crate::nzu!(3),
            adaptive_threshold_multiplier: 2.0,
            peak_picking: PeakPickingConfig::drums(),
        }
    }

    /// Preset for musical material.
    ///
    /// Uses `CqtParams::musical()`, a 512-sample hop and
    /// `PeakPickingConfig::music()`.
    #[inline]
    #[must_use]
    pub fn musical() -> Self {
        Self {
            cqt_params: CqtParams::musical(),
            hop_size: crate::nzu!(512),
            window_size: None,
            threshold: 0.25,
            min_onset_interval_secs: 0.1,
            pre_emphasis: 0.1,
            adaptive_threshold: true,
            median_filter_length: crate::nzu!(7),
            adaptive_threshold_multiplier: 1.2,
            peak_picking: PeakPickingConfig::music(),
        }
    }

    /// Preset for speech.
    ///
    /// Uses `CqtParams::onset_detection()`, a 256-sample hop and
    /// `PeakPickingConfig::speech()`.
    #[inline]
    #[must_use]
    pub fn speech() -> Self {
        Self {
            cqt_params: CqtParams::onset_detection(),
            hop_size: crate::nzu!(256),
            window_size: None,
            threshold: 0.2,
            min_onset_interval_secs: 0.08,
            pre_emphasis: 0.05,
            adaptive_threshold: true,
            median_filter_length: crate::nzu!(9),
            adaptive_threshold_multiplier: 1.1,
            peak_picking: PeakPickingConfig::speech(),
        }
    }
}
impl Default for OnsetDetectionConfig {
    /// General-purpose defaults: `CqtParams::onset_detection()`, a 512-sample
    /// hop, a derived window size, adaptive thresholding enabled with a
    /// 5-frame median filter and a 1.5 multiplier, and the default
    /// peak-picking settings.
    #[inline]
    fn default() -> Self {
        Self {
            cqt_params: CqtParams::onset_detection(),
            hop_size: crate::nzu!(512),
            window_size: None,
            threshold: 0.3,
            min_onset_interval_secs: 0.1,
            pre_emphasis: 0.1,
            adaptive_threshold: true,
            median_filter_length: crate::nzu!(5),
            adaptive_threshold_multiplier: 1.5,
            peak_picking: PeakPickingConfig::default(),
        }
    }
}
/// Selects which flux kernel `spectral_flux` applies.
///
/// Parsed case-insensitively from `"energy"`, `"magnitude"`, `"complex"` and
/// `"rectified_complex"` via `FromStr`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum SpectralFluxMethod {
/// `energy_flux` over the CQT magnitude spectrogram.
Energy,
/// `magnitude_flux` over the CQT magnitude spectrogram.
Magnitude,
/// `complex_flux` over the constant-Q transform data.
Complex,
/// `rectified_complex_flux` over the constant-Q transform data.
RectifiedComplex,
}
impl FromStr for SpectralFluxMethod {
    type Err = AudioSampleError;

    /// Parses a method name, ignoring case.
    ///
    /// Accepted names are `"energy"`, `"magnitude"`, `"complex"` and
    /// `"rectified_complex"`.
    ///
    /// # Errors
    ///
    /// Returns an `AudioSampleError::Parameter` carrying the original,
    /// un-lowercased input when the name is not recognised.
    #[inline]
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let method = match lowered.as_str() {
            "energy" => Self::Energy,
            "magnitude" => Self::Magnitude,
            "complex" => Self::Complex,
            "rectified_complex" => Self::RectifiedComplex,
            _ => {
                let reason = ParameterError::invalid_value("spectral_flux_method", s);
                return Err(AudioSampleError::Parameter(reason));
            }
        };
        Ok(method)
    }
}
/// Configuration for spectral-flux onset detection
/// (`detect_onsets_spectral_flux`).
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct SpectralFluxConfig {
/// Constant-Q transform parameters passed to `spectral_flux`. `f_min` is
/// also used to derive the window size when `window_size` is `None`.
pub cqt_params: CqtParams,
/// Hop between analysis frames, in samples. Also used to convert peak
/// indices to seconds.
pub hop_size: NonZeroUsize,
/// Explicit analysis window size. When `None`, four periods of
/// `cqt_params.f_min` at the signal's sample rate are used.
pub window_size: Option<NonZeroUsize>,
/// Which flux kernel to apply.
pub flux_method: SpectralFluxMethod,
/// Peak-picking settings passed to `pick_peaks`; also checked by `validate`.
pub peak_picking: PeakPickingConfig,
/// When `true`, `rectify_inplace` is applied to the flux before peak picking.
pub rectify: bool,
/// Factor passed to `log_compress_inplace`. Compression is applied only
/// when this is greater than `0.0`; `validate` rejects negative values.
pub log_compression: f64,
}
impl SpectralFluxConfig {
    /// Builds a configuration from explicit values.
    ///
    /// No validation is performed; call [`Self::validate`] to check the result.
    #[inline]
    #[must_use]
    pub const fn new(
        cqt_params: CqtParams,
        hop_size: NonZeroUsize,
        window_size: Option<NonZeroUsize>,
        flux_method: SpectralFluxMethod,
        peak_picking: PeakPickingConfig,
        rectify: bool,
        log_compression: f64,
    ) -> Self {
        Self {
            cqt_params,
            hop_size,
            window_size,
            flux_method,
            peak_picking,
            rectify,
            log_compression,
        }
    }

    /// Preset for percussive material: energy flux, 256-sample hop,
    /// `CqtParams::percussive()` and `PeakPickingConfig::drums()`.
    #[inline]
    #[must_use]
    pub fn percussive() -> Self {
        Self {
            cqt_params: CqtParams::percussive(),
            hop_size: crate::nzu!(256),
            window_size: None,
            flux_method: SpectralFluxMethod::Energy,
            peak_picking: PeakPickingConfig::drums(),
            rectify: true,
            log_compression: 1000.0,
        }
    }

    /// Preset for musical material: magnitude flux, 512-sample hop,
    /// `CqtParams::onset_detection()` and `PeakPickingConfig::music()`.
    #[inline]
    #[must_use]
    pub fn musical() -> Self {
        Self {
            cqt_params: CqtParams::onset_detection(),
            hop_size: crate::nzu!(512),
            window_size: None,
            flux_method: SpectralFluxMethod::Magnitude,
            peak_picking: PeakPickingConfig::music(),
            rectify: true,
            log_compression: 100.0,
        }
    }

    /// Preset using complex flux without rectification: 512-sample hop,
    /// `CqtParams::onset_detection()` and the default peak-picking settings.
    #[inline]
    #[must_use]
    pub fn complex() -> Self {
        Self {
            cqt_params: CqtParams::onset_detection(),
            hop_size: crate::nzu!(512),
            window_size: None,
            flux_method: SpectralFluxMethod::Complex,
            peak_picking: PeakPickingConfig::default(),
            rectify: false,
            log_compression: 100.0,
        }
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns an error if `peak_picking` fails its own validation, or if
    /// `log_compression` is negative or NaN.
    #[inline]
    pub fn validate(&self) -> AudioSampleResult<()> {
        self.peak_picking.validate()?;
        // `NaN < 0.0` is false, so NaN must be rejected explicitly; otherwise
        // it would slip through and poison every downstream computation.
        if self.log_compression.is_nan() || self.log_compression < 0.0 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "parameter",
                "Log compression factor must be non-negative",
            )));
        }
        Ok(())
    }
}
/// Configuration for complex-domain onset detection
/// (`complex_onset_detection` and `onset_detection_function_complex`).
///
/// The whole configuration is handed to `phase_deviation` and
/// `combine_complex_odf`; fields not read directly in this file are
/// presumably consumed there — confirm against those helpers.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct ComplexOnsetConfig {
/// Constant-Q transform parameters passed to `constant_q_transform`.
pub cqt_config: CqtParams,
/// Hop between analysis frames, in samples. Also used to convert peak
/// indices to seconds.
pub hop_size: NonZeroUsize,
/// Optional analysis window size.
/// NOTE(review): not read by the detection methods in this file — confirm.
pub window_size: Option<NonZeroUsize>,
/// Peak-picking settings passed to `pick_peaks`; also checked by `validate`.
pub peak_picking: PeakPickingConfig,
/// Weight of the magnitude-difference term. `validate` requires it to lie
/// in `[0.0, 1.0]`, and not to be zero together with `phase_weight`.
pub magnitude_weight: f64,
/// Weight of the phase-deviation term. `validate` requires it to lie in
/// `[0.0, 1.0]`, and not to be zero together with `magnitude_weight`.
pub phase_weight: f64,
/// Whether to rectify the magnitude term (presumably applied inside
/// `combine_complex_odf` — confirm).
pub magnitude_rectify: bool,
/// Whether to rectify the phase term (presumably applied inside
/// `combine_complex_odf` — confirm).
pub phase_rectify: bool,
/// Log-compression factor; `validate` rejects negative values.
pub log_compression: f64,
}
impl ComplexOnsetConfig {
    /// Builds a configuration from explicit values.
    ///
    /// No validation or clamping is performed; call [`Self::validate`] to
    /// check the result.
    #[inline]
    #[must_use]
    pub const fn new(
        cqt_config: CqtParams,
        hop_size: NonZeroUsize,
        window_size: Option<NonZeroUsize>,
        peak_picking: PeakPickingConfig,
        magnitude_weight: f64,
        phase_weight: f64,
        magnitude_rectify: bool,
        phase_rectify: bool,
        log_compression: f64,
    ) -> Self {
        Self {
            cqt_config,
            hop_size,
            window_size,
            peak_picking,
            magnitude_weight,
            phase_weight,
            magnitude_rectify,
            phase_rectify,
            log_compression,
        }
    }

    /// Preset for percussive material: 256-sample hop, weights 0.8 / 0.2
    /// (magnitude / phase) and `PeakPickingConfig::drums()`.
    #[inline]
    #[must_use]
    pub fn percussive() -> Self {
        Self {
            cqt_config: CqtParams::onset_detection(),
            hop_size: crate::nzu!(256),
            window_size: None,
            peak_picking: PeakPickingConfig::drums(),
            magnitude_weight: 0.8,
            phase_weight: 0.2,
            magnitude_rectify: true,
            phase_rectify: true,
            log_compression: 1000.0,
        }
    }

    /// Preset for musical material: 512-sample hop, weights 0.6 / 0.4
    /// (magnitude / phase) and `PeakPickingConfig::music()`.
    #[inline]
    #[must_use]
    pub fn musical() -> Self {
        Self {
            cqt_config: CqtParams::onset_detection(),
            hop_size: crate::nzu!(512),
            window_size: None,
            peak_picking: PeakPickingConfig::music(),
            magnitude_weight: 0.6,
            phase_weight: 0.4,
            magnitude_rectify: true,
            phase_rectify: true,
            log_compression: 100.0,
        }
    }

    /// Preset for speech: 256-sample hop, equal weights, phase rectification
    /// disabled and `PeakPickingConfig::speech()`.
    #[inline]
    #[must_use]
    pub fn speech() -> Self {
        Self {
            cqt_config: CqtParams::onset_detection(),
            hop_size: crate::nzu!(256),
            window_size: None,
            peak_picking: PeakPickingConfig::speech(),
            magnitude_weight: 0.5,
            phase_weight: 0.5,
            magnitude_rectify: true,
            phase_rectify: false,
            log_compression: 50.0,
        }
    }

    /// Sets both weights, clamping each to `[0.0, 1.0]`.
    ///
    /// This does not check that at least one weight is non-zero, and a NaN
    /// input stays NaN after clamping; [`Self::validate`] reports both cases.
    #[inline]
    pub const fn set_weights(&mut self, magnitude_weight: f64, phase_weight: f64) {
        self.magnitude_weight = magnitude_weight.clamp(0.0, 1.0);
        self.phase_weight = phase_weight.clamp(0.0, 1.0);
    }

    /// Checks that the configuration is usable.
    ///
    /// # Errors
    ///
    /// Returns an error if:
    /// - `peak_picking` fails its own validation;
    /// - either weight is outside `[0.0, 1.0]` or is NaN;
    /// - both weights are zero;
    /// - `log_compression` is negative or NaN.
    #[inline]
    pub fn validate(&self) -> AudioSampleResult<()> {
        self.peak_picking.validate()?;

        // `contains` is false for NaN, so NaN weights are rejected here;
        // plain `<` / `>` comparisons would silently let them through.
        let unit_interval = 0.0..=1.0;
        if !unit_interval.contains(&self.magnitude_weight) {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "parameter",
                "Magnitude weight must be between 0.0 and 1.0",
            )));
        }
        if !unit_interval.contains(&self.phase_weight) {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "parameter",
                "Phase weight must be between 0.0 and 1.0",
            )));
        }
        if self.magnitude_weight == 0.0 && self.phase_weight == 0.0 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "parameter",
                "At least one of magnitude or phase weight must be greater than 0",
            )));
        }
        // `NaN < 0.0` is false, so NaN needs its own check.
        if self.log_compression.is_nan() || self.log_compression < 0.0 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "parameter",
                "Log compression factor must be non-negative",
            )));
        }
        Ok(())
    }
}
impl Default for ComplexOnsetConfig {
    /// General-purpose defaults: `CqtParams::onset_detection()`, a 512-sample
    /// hop, no explicit window size, default peak picking, weights 0.7 / 0.3
    /// (magnitude / phase), both terms rectified and a log-compression
    /// factor of 100.
    #[inline]
    fn default() -> Self {
        let magnitude_weight = 0.7;
        let phase_weight = 0.3;
        Self::new(
            CqtParams::onset_detection(),
            crate::nzu!(512),
            None,
            PeakPickingConfig::default(),
            magnitude_weight,
            phase_weight,
            true,
            true,
            100.0,
        )
    }
}
impl<T> AudioOnsetDetection for AudioSamples<'_, T>
where
    T: StandardSample,
    Self: AudioTypeConversion<Sample = T>,
{
    /// Detects onsets and returns their times in seconds.
    ///
    /// Computes the onset detection function, picks peaks with
    /// `config.peak_picking`, and converts each peak's frame index to seconds
    /// using the hop size and the signal's sample rate.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::onset_detection_function`] and from
    /// `pick_peaks`.
    #[inline]
    fn detect_onsets(&self, config: &OnsetDetectionConfig) -> AudioSampleResult<Vec<f64>> {
        let sample_rate = self.sample_rate_hz();
        let (_times, odf) = self.onset_detection_function(config)?;
        let peaks = pick_peaks(&odf, &config.peak_picking)?;
        Ok(peaks
            .into_iter()
            .map(|idx| config.frame_to_seconds(idx, sample_rate))
            .collect())
    }

    /// Computes the energy-based onset detection function.
    ///
    /// Returns `(times, odf)`. When the magnitude spectrogram's second
    /// dimension is smaller than 2, a single-element pair
    /// `([0.0], [0.0])` is returned instead.
    ///
    /// If `config.adaptive_threshold` is set, the detection function is
    /// median-filtered and passed through `apply_adaptive_threshold`.
    ///
    /// # Errors
    ///
    /// Propagates errors from building the STFT / spectrogram parameters,
    /// from computing the CQT magnitude spectrogram, and from `median_filter`.
    #[inline]
    fn onset_detection_function(
        &self,
        config: &OnsetDetectionConfig,
    ) -> AudioSampleResult<(NonEmptyVec<f64>, NonEmptyVec<f64>)> {
        let sample_rate = self.sample_rate_hz();
        let window_size = config.effective_window_size(sample_rate);
        let cqt_params = &config.cqt_params;

        let stft_params = StftParams::builder()
            .n_fft(window_size)
            .hop_size(config.hop_size)
            .window(cqt_params.window.clone())
            .centre(true)
            .build()?;
        let spectrogram_params = SpectrogramParams::new(stft_params, sample_rate)?;
        let mag = self.cqt_magnitude_spectrogram(&spectrogram_params, cqt_params)?;

        // Too little data to form a detection function: return a
        // single-sample placeholder.
        if mag.dim().1 < 2 {
            return Ok((non_empty_vec![0.0], non_empty_vec![0.0]));
        }

        let mut odf = energy_odf(&mag);
        if config.adaptive_threshold {
            let median = median_filter(&odf, config.median_filter_length)?;
            apply_adaptive_threshold(&mut odf, &median, config.adaptive_threshold_multiplier);
        }

        let time_frames = mag.times().to_non_empty_vec();
        Ok((time_frames, odf))
    }

    /// Detects onsets from spectral flux and returns their times in seconds.
    ///
    /// The flux is optionally rectified (`config.rectify`) and
    /// log-compressed (`config.log_compression > 0.0`) before peak picking.
    ///
    /// When `config.window_size` is `None`, the window is four periods of
    /// `config.cqt_params.f_min`; if that truncates to zero the smallest
    /// non-zero size is used instead.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::spectral_flux`] and from `pick_peaks`.
    #[inline]
    fn detect_onsets_spectral_flux(
        &self,
        config: &SpectralFluxConfig,
    ) -> AudioSampleResult<Vec<f64>> {
        let sample_rate = self.sample_rate_hz();
        let window_size = config.window_size.unwrap_or_else(|| {
            let min_period = sample_rate / config.cqt_params.f_min;
            // The float-to-int cast can yield 0 (NaN, non-positive input, or
            // `f_min > 4 * sample_rate`); building a `NonZeroUsize` from 0
            // unchecked would be undefined behaviour, so check and fall back.
            NonZeroUsize::new((min_period * 4.0) as usize).unwrap_or(NonZeroUsize::MIN)
        });

        let (_times, mut flux) = self.spectral_flux(
            &config.cqt_params,
            window_size,
            config.hop_size,
            config.flux_method,
        )?;

        if config.rectify {
            rectify_inplace(&mut flux);
        }
        if config.log_compression > 0.0 {
            log_compress_inplace(&mut flux, config.log_compression);
        }

        let peaks = pick_peaks(&flux, &config.peak_picking)?;
        let hop = config.hop_size.get() as f64;
        Ok(peaks
            .into_iter()
            .map(|idx| (idx as f64 * hop) / sample_rate)
            .collect())
    }

    /// Computes spectral flux with the requested `method`.
    ///
    /// Returns `(times, flux)`. The `Energy` and `Magnitude` methods work on
    /// the CQT magnitude spectrogram and take their time axis from it; the
    /// `Complex` and `RectifiedComplex` methods work on the constant-Q
    /// transform data and use `frame_index * hop_size / sample_rate` as the
    /// time axis.
    ///
    /// # Errors
    ///
    /// Propagates errors from building the STFT / spectrogram parameters
    /// (built for every method) and from the underlying transform.
    #[inline]
    fn spectral_flux(
        &self,
        config: &CqtParams,
        window_size: NonZeroUsize,
        hop_size: NonZeroUsize,
        method: SpectralFluxMethod,
    ) -> AudioSampleResult<(NonEmptyVec<f64>, NonEmptyVec<f64>)> {
        let sample_rate = self.sample_rate_hz();

        // Built up front for every method so that invalid parameters are
        // reported regardless of which branch runs.
        let stft_params = StftParams::builder()
            .n_fft(window_size)
            .hop_size(hop_size)
            .window(config.window.clone())
            .centre(true)
            .build()?;
        let spectrogram_params = SpectrogramParams::new(stft_params, sample_rate)?;

        // Time axis for the complex-domain branches, one entry per CQT frame.
        let cqt_frame_times = |n_frames: usize| -> NonEmptyVec<f64> {
            let frame_times: Vec<f64> = (0..n_frames)
                .map(|i| i as f64 * hop_size.get() as f64 / sample_rate)
                .collect();
            // SAFETY: callers pass `cqt_result.n_frames().get()`, which is
            // taken from a non-zero frame count, so `frame_times` holds at
            // least one element.
            unsafe { NonEmptyVec::new_unchecked(frame_times) }
        };

        let (times, flux) = match method {
            SpectralFluxMethod::Energy => {
                let mag = self.cqt_magnitude_spectrogram(&spectrogram_params, config)?;
                (mag.times().to_non_empty_vec(), energy_flux(&mag))
            }
            SpectralFluxMethod::Magnitude => {
                let mag = self.cqt_magnitude_spectrogram(&spectrogram_params, config)?;
                (mag.times().to_non_empty_vec(), magnitude_flux(&mag))
            }
            SpectralFluxMethod::Complex => {
                let cqt_result = self.constant_q_transform(config, hop_size)?;
                let times = cqt_frame_times(cqt_result.n_frames().get());
                (times, complex_flux(&cqt_result.data))
            }
            SpectralFluxMethod::RectifiedComplex => {
                let cqt_result = self.constant_q_transform(config, hop_size)?;
                let times = cqt_frame_times(cqt_result.n_frames().get());
                (times, rectified_complex_flux(&cqt_result.data))
            }
        };
        Ok((times, flux))
    }

    /// Detects onsets with the complex-domain method and returns their times
    /// in seconds.
    ///
    /// # Errors
    ///
    /// Propagates errors from [`Self::onset_detection_function_complex`] and
    /// from `pick_peaks`.
    #[inline]
    fn complex_onset_detection(
        &self,
        onset_config: &ComplexOnsetConfig,
    ) -> AudioSampleResult<Vec<f64>> {
        let sample_rate = self.sample_rate_hz();
        let odf = self.onset_detection_function_complex(onset_config)?;
        let peaks = pick_peaks(&odf, &onset_config.peak_picking)?;
        let hop = onset_config.hop_size.get() as f64;
        Ok(peaks
            .into_iter()
            .map(|idx| (idx as f64 * hop) / sample_rate)
            .collect())
    }

    /// Computes the complex-domain onset detection function by combining the
    /// magnitude-difference and phase-deviation matrices.
    ///
    /// The constant-Q transform is computed once and shared by both terms,
    /// rather than once per term.
    ///
    /// # Errors
    ///
    /// Propagates errors from the constant-Q transform.
    #[inline]
    fn onset_detection_function_complex(
        &self,
        onset_config: &ComplexOnsetConfig,
    ) -> AudioSampleResult<NonEmptyVec<f64>> {
        let sample_rate = self.sample_rate_hz();
        let cqt_result =
            self.constant_q_transform(&onset_config.cqt_config, onset_config.hop_size)?;

        // Phase term first: it only borrows the transform data, leaving
        // `cqt_result` free for the magnitude conversion afterwards.
        let phase_dev = phase_deviation(cqt_result.data.view(), onset_config, sample_rate);
        let mag = cqt_result.to_magnitude();
        let mag_diff = magnitude_difference(mag.view());

        Ok(combine_complex_odf(&mag_diff, &phase_dev, onset_config))
    }

    /// Computes the magnitude-difference matrix of the constant-Q transform.
    ///
    /// # Errors
    ///
    /// Propagates errors from the constant-Q transform.
    #[inline]
    fn magnitude_difference_matrix(
        &self,
        config: &ComplexOnsetConfig,
    ) -> AudioSampleResult<Array2<f64>> {
        let cqt_result = self.constant_q_transform(&config.cqt_config, config.hop_size)?;
        let mag = cqt_result.to_magnitude();
        Ok(magnitude_difference(mag.view()))
    }

    /// Computes the phase-deviation matrix of the constant-Q transform.
    ///
    /// # Errors
    ///
    /// Propagates errors from the constant-Q transform.
    #[inline]
    fn phase_deviation_matrix(
        &self,
        config: &ComplexOnsetConfig,
    ) -> AudioSampleResult<ndarray::Array2<f64>> {
        let sample_rate = self.sample_rate_hz();
        let cqt_result = self.constant_q_transform(&config.cqt_config, config.hop_size)?;
        Ok(phase_deviation(cqt_result.data.view(), config, sample_rate))
    }
}