use std::num::NonZeroUsize;
use non_empty_iter::{IntoNonEmptyIterator, NonEmptyIterator};
use non_empty_slice::{NonEmptySlice, NonEmptyVec};
use crate::operations::types::{
AdaptiveThresholdConfig, AdaptiveThresholdMethod, NormalizationMethod, PeakPickingConfig,
};
use crate::{AudioSampleError, AudioSampleResult, ParameterError};
/// Computes a per-sample adaptive threshold over a centred sliding window.
///
/// For sample `i` the window spans `i - window_size / 2 ..= i + window_size / 2`,
/// clipped to the signal bounds (so an even `window_size` behaves like
/// `window_size + 1`). The raw threshold depends on `config.method`:
/// - `Delta`: local maximum minus `config.delta`;
/// - `Percentile`: the `config.percentile` quantile of the window (linearly interpolated);
/// - `Combined`: the larger of the two above.
///
/// Each raw threshold is then bounded below by `config.min_threshold` and above by
/// `config.max_threshold`. The returned vector has one entry per input sample.
///
/// # Errors
/// Returns any error produced by `config.validate()`.
#[inline]
pub fn adaptive_threshold(
    onset_strength: &NonEmptySlice<f64>,
    config: &AdaptiveThresholdConfig,
) -> AudioSampleResult<Vec<f64>> {
    config.validate()?;

    let total = onset_strength.len().get();
    let radius = config.window_size / 2;

    // Unbounded threshold for a single neighbourhood, according to the configured method.
    let raw_threshold = |neighbourhood: &NonEmptySlice<f64>| -> f64 {
        let local_peak = || {
            neighbourhood
                .iter()
                .fold(f64::NEG_INFINITY, |best, &v| best.max(v))
        };
        match config.method {
            AdaptiveThresholdMethod::Delta => local_peak() - config.delta,
            AdaptiveThresholdMethod::Percentile => percentile(neighbourhood, config.percentile),
            AdaptiveThresholdMethod::Combined => {
                let delta_threshold = local_peak() - config.delta;
                let percentile_threshold = percentile(neighbourhood, config.percentile);
                delta_threshold.max(percentile_threshold)
            }
        }
    };

    let mut thresholds = Vec::with_capacity(total);
    for centre in 0..total {
        let start = centre.saturating_sub(radius);
        let end = (centre + radius + 1).min(total);
        // `start <= centre < end`, so the window always holds at least one sample.
        let window = &onset_strength[start..end];
        let window = non_empty_slice::non_empty_slice!(window);
        // `max` then `min` (rather than `clamp`) keeps the original NaN handling.
        thresholds.push(
            raw_threshold(window)
                .max(config.min_threshold)
                .min(config.max_threshold),
        );
    }
    Ok(thresholds)
}
/// Picks onset peaks from an onset-strength envelope.
///
/// Pipeline:
/// 1. Validate `config`.
/// 2. Adapt parameters that are too large for short signals:
///    - an adaptive-threshold window longer than the signal is shrunk to the signal
///      length (at least 3, rounded up to an odd size);
///    - a `min_peak_separation` of at least half the signal length is reduced to a
///      quarter of it (at least 1).
/// 3. Optionally apply pre-emphasis, median filtering and normalization, in that order.
/// 4. Compute a per-sample adaptive threshold on the processed envelope, using the
///    adjusted threshold configuration.
/// 5. Keep interior samples that are strict local maxima and lie strictly above their
///    threshold.
/// 6. Enforce the minimum separation, preferring stronger peaks.
///
/// Returns the selected peak indices in ascending order (empty if none qualify).
///
/// # Errors
/// - Returns a parameter error if `config` fails validation.
/// - Propagates errors from the pre-processing stages (invalid pre-emphasis
///   coefficient, even median length).
/// - Propagates errors from [`adaptive_threshold`].
#[inline]
pub fn pick_peaks(
    onset_strength: &NonEmptySlice<f64>,
    config: &PeakPickingConfig,
) -> AudioSampleResult<Vec<usize>> {
    config.validate().map_err(|e| {
        AudioSampleError::Parameter(ParameterError::invalid_value(
            "peak_picking_config",
            format!("Invalid peak picking config: {e}"),
        ))
    })?;

    let signal_len = onset_strength.len().get();
    let mut adjusted_config = *config;
    if adjusted_config.adaptive_threshold.window_size > signal_len {
        // `adaptive_threshold` only uses `window_size / 2` as the half-width, so an even
        // size and the next odd size produce identical windows. Rounding up to odd keeps
        // the thresholds unchanged while staying valid if odd sizes are required.
        // NOTE(review): confirm whether `AdaptiveThresholdConfig::validate` requires odd sizes.
        adjusted_config.adaptive_threshold.window_size = signal_len.max(3) | 1;
    }
    if adjusted_config.min_peak_separation.get() >= signal_len / 2 {
        adjusted_config.min_peak_separation =
            NonZeroUsize::new(signal_len / 4).unwrap_or(NonZeroUsize::MIN);
    }

    let mut processed_strength = onset_strength.to_non_empty_vec();
    if adjusted_config.pre_emphasis {
        processed_strength =
            apply_pre_emphasis(&processed_strength, adjusted_config.pre_emphasis_coeff)?;
    }
    if adjusted_config.median_filter {
        processed_strength =
            apply_median_filter(&processed_strength, adjusted_config.median_filter_length)?;
    }
    if adjusted_config.normalize_onset_strength {
        processed_strength =
            normalize_onset_strength(&processed_strength, adjusted_config.normalization_method);
    }

    // Use the adjusted threshold config so the short-signal window clamp takes effect.
    let thresholds =
        adaptive_threshold(&processed_strength, &adjusted_config.adaptive_threshold)?;

    // Strict local maxima above their threshold. The first and last samples have only
    // one neighbour and are never candidates (`len >= 1`, so `len - 1` cannot underflow).
    let len = processed_strength.len().get();
    let mut candidates = Vec::new();
    for i in 1..len - 1 {
        let current = processed_strength[i];
        let prev = processed_strength[i - 1];
        let next = processed_strength[i + 1];
        if current > prev && current > next && current > thresholds[i] {
            candidates.push((i, current));
        }
    }
    if candidates.is_empty() {
        return Ok(Vec::new());
    }

    // SAFETY: `candidates` was just checked to be non-empty.
    let candidates = unsafe { NonEmptyVec::new_unchecked(candidates) };
    Ok(apply_temporal_constraints(
        &candidates,
        adjusted_config.min_peak_separation.get(),
    ))
}
/// Applies first-order pre-emphasis: `y[0] = x[0]` and `y[n] = x[n] - coeff * x[n - 1]`.
///
/// The difference is evaluated with a fused multiply-add. The output has the same
/// length as the input.
///
/// # Errors
/// Returns a parameter error if `coeff` is not within `[0.0, 1.0]`. A NaN coefficient
/// is also rejected.
#[inline]
pub fn apply_pre_emphasis(
    signal: &NonEmptySlice<f64>,
    coeff: f64,
) -> AudioSampleResult<NonEmptyVec<f64>> {
    let coeff_is_valid = (0.0..=1.0).contains(&coeff);
    if !coeff_is_valid {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "pre_emphasis_coefficient",
            "Pre-emphasis coefficient must be between 0.0 and 1.0",
        )));
    }

    let samples: &[f64] = signal;
    let mut emphasized = Vec::with_capacity(samples.len());
    // The first sample has no predecessor and passes through unchanged.
    emphasized.push(samples[0]);
    emphasized.extend(
        samples
            .windows(2)
            .map(|pair| coeff.mul_add(-pair[0], pair[1])),
    );

    // SAFETY: the first sample was pushed above, so `emphasized` is non-empty.
    Ok(unsafe { NonEmptyVec::new_unchecked(emphasized) })
}
/// Applies a centred sliding median filter of odd length `filter_length`.
///
/// - Near the edges the window is truncated to the samples that exist.
/// - When a truncated window holds an even number of samples, the upper of the two
///   middle values is used.
/// - During sorting, incomparable values (NaN) are treated as equal.
/// - A length of 1 returns a copy of the input.
///
/// # Errors
/// Returns a parameter error when `filter_length` is even.
#[inline]
pub fn apply_median_filter(
    signal: &NonEmptySlice<f64>,
    filter_length: NonZeroUsize,
) -> AudioSampleResult<NonEmptyVec<f64>> {
    let taps = filter_length.get();
    if taps.is_multiple_of(2) {
        return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "filter_length",
            "Median filter length must be odd and greater than 0",
        )));
    }
    if taps == 1 {
        // A single-sample median is the identity.
        return Ok(signal.to_owned());
    }

    let samples: &[f64] = signal;
    let count = samples.len();
    let reach = taps / 2;
    // One scratch buffer is reused for every window instead of allocating per sample.
    let mut scratch: Vec<f64> = Vec::with_capacity(taps.min(count));
    let mut medians = Vec::with_capacity(count);
    for centre in 0..count {
        let from = centre.saturating_sub(reach);
        let to = count.min(centre + reach + 1);
        scratch.clear();
        scratch.extend_from_slice(&samples[from..to]);
        scratch.sort_by(|x, y| x.partial_cmp(y).unwrap_or(std::cmp::Ordering::Equal));
        medians.push(scratch[scratch.len() / 2]);
    }

    // SAFETY: `signal` is non-empty, so the loop pushed at least one median.
    Ok(unsafe { NonEmptyVec::new_unchecked(medians) })
}
/// Normalizes an onset-strength envelope according to `method`.
///
/// - `Peak`: divide by the largest absolute value (returned unchanged if that is 0).
/// - `MinMax`: rescale to `[0, 1]` (returned unchanged if the range is below
///   `f64::EPSILON`).
/// - `ZScore`: subtract the mean and divide by the population standard deviation
///   (returned unchanged if the variance is exactly 0).
/// - `Mean`: subtract the arithmetic mean.
/// - `Median`: subtract the median. For an even number of samples the upper of the
///   two middle values is used.
#[inline]
#[must_use]
pub fn normalize_onset_strength(
    onset_strength: &NonEmptySlice<f64>,
    method: NormalizationMethod,
) -> NonEmptyVec<f64> {
    let sample_count = onset_strength.len().get() as f64;
    // Arithmetic mean, accumulated left to right starting from +0.0.
    let mean_of = || onset_strength.iter().fold(0.0, |acc, &x| acc + x) / sample_count;

    match method {
        NormalizationMethod::Peak => {
            let peak = onset_strength
                .iter()
                .fold(0.0_f64, |acc, &x| acc.max(x.abs()));
            if peak == 0.0 {
                onset_strength.to_owned()
            } else {
                onset_strength
                    .into_non_empty_iter()
                    .map(|&x| x / peak)
                    .collect_non_empty()
            }
        }
        NormalizationMethod::MinMax => {
            let (lowest, highest) = onset_strength.iter().fold(
                (f64::INFINITY, f64::NEG_INFINITY),
                |(lo, hi), &x| (lo.min(x), hi.max(x)),
            );
            let span = highest - lowest;
            if span.abs() < f64::EPSILON {
                onset_strength.to_owned()
            } else {
                onset_strength
                    .into_non_empty_iter()
                    .map(|&x| (x - lowest) / span)
                    .collect_non_empty()
            }
        }
        NormalizationMethod::ZScore => {
            let mu = mean_of();
            let variance = onset_strength
                .iter()
                .map(|&x| (x - mu).powi(2))
                .fold(0.0, |acc, sq| acc + sq)
                / sample_count;
            if variance == 0.0 {
                onset_strength.to_owned()
            } else {
                let sigma = variance.sqrt();
                onset_strength
                    .into_non_empty_iter()
                    .map(|&x| (x - mu) / sigma)
                    .collect_non_empty()
            }
        }
        NormalizationMethod::Mean => {
            let mu = mean_of();
            onset_strength
                .into_non_empty_iter()
                .map(|&x| x - mu)
                .collect_non_empty()
        }
        NormalizationMethod::Median => {
            let mut ordered = onset_strength.to_vec();
            ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
            let middle = ordered[ordered.len() / 2];
            onset_strength
                .into_non_empty_iter()
                .map(|&x| x - middle)
                .collect_non_empty()
        }
    }
}
/// Greedily selects peaks so that no two chosen indices are closer than `min_separation`.
///
/// Candidates are `(index, strength)` pairs and are visited from strongest to weakest.
/// The sort is stable, so ties keep their input order. A candidate is accepted only
/// if it is at least `min_separation` samples away from every previously accepted
/// index. The accepted indices are returned in ascending order.
fn apply_temporal_constraints(
    candidates: &NonEmptySlice<(usize, f64)>,
    min_separation: usize,
) -> Vec<usize> {
    let mut by_strength = candidates.to_vec();
    by_strength.sort_by(|lhs, rhs| {
        rhs.1
            .partial_cmp(&lhs.1)
            .unwrap_or(std::cmp::Ordering::Equal)
    });

    let mut accepted: Vec<usize> = Vec::new();
    for (index, _strength) in by_strength {
        let far_enough = accepted
            .iter()
            .all(|&kept| index.abs_diff(kept) >= min_separation);
        if far_enough {
            accepted.push(index);
        }
    }

    accepted.sort_unstable();
    accepted
}
/// Smooths an onset-strength envelope.
///
/// A median filter of length `median_length` is applied first. If `window_size`
/// is greater than 1, a centred moving average of that size is then applied.
///
/// # Errors
/// Propagates the error from [`apply_median_filter`] (e.g. when `median_length` is even).
#[inline]
pub fn smooth_onset_strength(
    onset_strength: &NonEmptySlice<f64>,
    window_size: NonZeroUsize,
    median_length: NonZeroUsize,
) -> AudioSampleResult<NonEmptyVec<f64>> {
    let median_filtered = apply_median_filter(onset_strength, median_length)?;
    if window_size.get() == 1 {
        // A one-sample moving average is the identity; skip it.
        return Ok(median_filtered);
    }
    Ok(apply_moving_average(&median_filtered, window_size))
}
/// Centred moving average with a half-width of `window_size / 2` samples.
///
/// Windows are truncated at the signal edges. Each output is the mean of the
/// samples actually inside its window. Because the half-width is `window_size / 2`,
/// an even `window_size` averages `window_size + 1` samples away from the edges.
/// A `window_size` of 1 returns a copy of the input.
fn apply_moving_average(
    signal: &NonEmptySlice<f64>,
    window_size: NonZeroUsize,
) -> NonEmptyVec<f64> {
    let span = window_size.get();
    if span == 1 {
        return signal.to_owned();
    }

    let samples: &[f64] = signal;
    let count = samples.len();
    let reach = span / 2;
    let averages: Vec<f64> = (0..count)
        .map(|centre| {
            let from = centre.saturating_sub(reach);
            let to = count.min(centre + reach + 1);
            let total: f64 = samples[from..to].iter().sum();
            total / (to - from) as f64
        })
        .collect();

    // SAFETY: `signal` is non-empty and `averages` has one entry per sample.
    unsafe { NonEmptyVec::new_unchecked(averages) }
}
/// Linearly interpolated quantile of `values`, with `percentile` given as a fraction.
///
/// The fractional rank `percentile * (n - 1)` is split into its floor and ceiling,
/// and the two neighbouring order statistics are blended by the fractional part.
/// During sorting, incomparable values (NaN) are treated as equal.
///
/// Callers are expected to pass a fraction in `[0, 1]`. Values above 1 can index
/// past the end and panic.
fn percentile(values: &NonEmptySlice<f64>, percentile: f64) -> f64 {
    let mut ordered = values.to_vec();
    ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));

    let last_index = ordered.len() - 1;
    if last_index == 0 {
        return ordered[0];
    }

    let rank = percentile * last_index as f64;
    let below = rank.floor() as usize;
    let above = rank.ceil() as usize;
    if below == above {
        return ordered[below];
    }

    let weight = rank - below as f64;
    ordered[below].mul_add(1.0 - weight, ordered[above] * weight)
}
#[cfg(test)]
mod tests {
    use non_empty_slice::{NonEmptyVec, non_empty_vec};

    use super::*;
    use crate::operations::types::NormalizationMethod;

    #[test]
    fn test_adaptive_threshold_delta() {
        let onset_strength =
            NonEmptySlice::from_slice(&[0.1, 0.3, 0.8, 0.2, 0.4, 0.9, 0.1]).unwrap();
        let config = AdaptiveThresholdConfig::delta(0.1, 3);
        let thresholds = adaptive_threshold(&onset_strength, &config).unwrap();
        assert_eq!(thresholds.len(), onset_strength.len().get());
        for &threshold in &thresholds {
            assert!(threshold >= config.min_threshold);
        }
    }

    #[test]
    fn test_adaptive_threshold_percentile() {
        let onset_strength = non_empty_vec![0.1, 0.3, 0.8, 0.2, 0.4, 0.9, 0.1];
        let config = AdaptiveThresholdConfig::percentile(0.8, 3);
        let thresholds = adaptive_threshold(&onset_strength, &config).unwrap();
        assert_eq!(thresholds.len(), onset_strength.len().get());
        for &threshold in &thresholds {
            assert!(threshold >= config.min_threshold);
            assert!(threshold <= config.max_threshold);
        }
    }

    #[test]
    fn test_adaptive_threshold_combined() {
        let onset_strength = non_empty_vec![0.1, 0.3, 0.8, 0.2, 0.4, 0.9, 0.1];
        let config = AdaptiveThresholdConfig::combined(0.1, 0.8, 3);
        let thresholds = adaptive_threshold(&onset_strength, &config).unwrap();
        assert_eq!(thresholds.len(), onset_strength.len().get());
    }

    #[test]
    fn test_pick_peaks_basic() {
        let onset_strength = non_empty_vec![0.1, 0.3, 0.8, 0.2, 0.4, 0.9, 0.1];
        let mut config = PeakPickingConfig::default();
        config.adaptive_threshold.window_size = 3;
        config.min_peak_separation = crate::nzu!(1);
        config.pre_emphasis = false;
        config.median_filter = false;
        config.normalize_onset_strength = false;
        let peaks = pick_peaks(&onset_strength, &config).unwrap();
        assert!(!peaks.is_empty());
        for &peak in &peaks {
            assert!(peak < onset_strength.len().get());
        }
    }

    #[test]
    fn test_pick_peaks_with_constraints() {
        let onset_strength = non_empty_vec![0.1, 0.5, 0.6, 0.7, 0.2, 0.8, 0.1];
        let mut config = PeakPickingConfig::default();
        config.min_peak_separation = crate::nzu!(3);
        let peaks = pick_peaks(&onset_strength, &config).unwrap();
        for i in 1..peaks.len() {
            assert!(peaks[i] - peaks[i - 1] >= config.min_peak_separation.get());
        }
    }

    #[test]
    fn test_pre_emphasis() {
        let signal: NonEmptyVec<f64> = non_empty_vec![1.0, 2.0, 3.0, 2.0, 1.0];
        let coeff: f64 = 0.97;
        let filtered = apply_pre_emphasis(&signal, coeff).unwrap();
        assert_eq!(filtered.len(), signal.len());
        assert_eq!(filtered[0], signal[0]);
        for i in 1..signal.len().get() {
            let expected = signal[i] - coeff * signal[i - 1];
            assert!((filtered[i] - expected).abs() < 1e-10);
        }
    }

    #[test]
    fn test_median_filter() {
        let signal: NonEmptyVec<f64> = non_empty_vec![1.0, 5.0, 2.0, 8.0, 3.0];
        let filtered = apply_median_filter(&signal, crate::nzu!(3)).unwrap();
        assert_eq!(filtered.len(), signal.len());
        // The spike at index 1 is suppressed by its neighbours.
        assert!(filtered[1] < signal[1]);
        // Edge windows are truncated to two samples and take the upper middle value.
        let values: &[f64] = &filtered;
        assert_eq!(values, &[5.0, 2.0, 5.0, 3.0, 8.0]);
    }

    #[test]
    fn test_median_filter_rejects_even_length() {
        let signal: NonEmptyVec<f64> = non_empty_vec![1.0, 5.0, 2.0];
        assert!(apply_median_filter(&signal, crate::nzu!(2)).is_err());
    }

    #[test]
    fn test_normalize_onset_strength_peak() {
        let onset_strength: NonEmptyVec<f64> = non_empty_vec![0.1, 0.5, 1.0, 0.3];
        let normalized = normalize_onset_strength(&onset_strength, NormalizationMethod::Peak);
        assert_eq!(normalized.len(), onset_strength.len());
        let max_abs = normalized.iter().fold(0.0f64, |acc, &x| acc.max(x.abs()));
        assert!((max_abs - 1.0f64).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_onset_strength_minmax() {
        let onset_strength: NonEmptyVec<f64> = non_empty_vec![0.1, 0.5, 1.0, 0.3];
        let normalized = normalize_onset_strength(&onset_strength, NormalizationMethod::MinMax);
        assert_eq!(normalized.len(), onset_strength.len());
        let min_val = normalized.iter().fold(f64::INFINITY, |acc, &x| acc.min(x));
        let max_val = normalized
            .iter()
            .fold(f64::NEG_INFINITY, |acc, &x| acc.max(x));
        assert!(min_val >= 0.0);
        assert!(max_val <= 1.0);
        assert!((max_val - 1.0).abs() < 1e-10);
        assert!((min_val - 0.0).abs() < 1e-10);
    }

    #[test]
    fn test_normalize_onset_strength_mean() {
        let onset_strength: NonEmptyVec<f64> = non_empty_vec![3.0, 1.0, 2.0];
        let centred = normalize_onset_strength(&onset_strength, NormalizationMethod::Mean);
        let values: &[f64] = &centred;
        assert_eq!(values, &[1.0, -1.0, 0.0]);
    }

    #[test]
    fn test_normalize_onset_strength_median() {
        let onset_strength: NonEmptyVec<f64> = non_empty_vec![3.0, 1.0, 2.0];
        let centred = normalize_onset_strength(&onset_strength, NormalizationMethod::Median);
        let values: &[f64] = &centred;
        assert_eq!(values, &[1.0, -1.0, 0.0]);
    }

    #[test]
    fn test_normalize_onset_strength_zscore() {
        let onset_strength: NonEmptyVec<f64> = non_empty_vec![1.0, 3.0];
        let standardized = normalize_onset_strength(&onset_strength, NormalizationMethod::ZScore);
        let values: &[f64] = &standardized;
        assert_eq!(values, &[-1.0, 1.0]);
    }

    #[test]
    fn test_smooth_onset_strength() {
        let onset_strength: NonEmptyVec<f64> = non_empty_vec![0.1, 0.9, 0.1, 0.8, 0.2];
        let smoothed =
            smooth_onset_strength(&onset_strength, crate::nzu!(3), crate::nzu!(3)).unwrap();
        assert_eq!(smoothed.len(), onset_strength.len());
        let original_std = standard_deviation(&onset_strength);
        let smoothed_std = standard_deviation(&smoothed);
        assert!(smoothed_std <= original_std);
    }

    #[test]
    fn test_moving_average() {
        let signal: NonEmptyVec<f64> = non_empty_vec![1.0, 2.0, 3.0];
        let averaged = apply_moving_average(&signal, crate::nzu!(3));
        // Edge windows only average the samples that exist.
        let values: &[f64] = &averaged;
        assert_eq!(values, &[1.5, 2.0, 2.5]);
    }

    #[test]
    fn test_temporal_constraints() {
        let candidates = non_empty_vec![(1, 0.8), (2, 0.6), (5, 0.9), (6, 0.7)];
        let min_separation = 2;
        let selected = apply_temporal_constraints(&candidates, min_separation);
        for i in 1..selected.len() {
            assert!(selected[i] - selected[i - 1] >= min_separation);
        }
        // Strongest first: 5 is kept, 1 is far enough, 6 and 2 are too close.
        assert_eq!(selected, vec![1, 5]);
    }

    #[test]
    fn test_percentile() {
        let values: NonEmptyVec<f64> = non_empty_vec![1.0, 2.0, 3.0, 4.0, 5.0];
        assert_eq!(percentile(&values, 0.0), 1.0);
        assert_eq!(percentile(&values, 1.0), 5.0);
        assert_eq!(percentile(&values, 0.5), 3.0);
        // Rank 0.25 * 4 = 1 lands exactly on the second order statistic.
        assert_eq!(percentile(&values, 0.25), 2.0);
        // Rank 0.125 * 4 = 0.5 interpolates halfway between 1.0 and 2.0.
        assert_eq!(percentile(&values, 0.125), 1.5);
    }

    #[test]
    fn test_edge_cases() {
        let config = AdaptiveThresholdConfig::default();
        let single: NonEmptyVec<f64> = non_empty_vec![1.0];
        let thresholds = adaptive_threshold(&single, &config).unwrap();
        assert_eq!(thresholds.len(), 1);
        let zeros: NonEmptyVec<f64> = non_empty_vec![0.0, 0.0, 0.0];
        let normalized = normalize_onset_strength(&zeros, NormalizationMethod::Peak);
        assert_eq!(normalized, zeros);
    }

    #[test]
    fn test_config_validation() {
        let mut config = AdaptiveThresholdConfig::default();
        config.delta = -0.1;
        assert!(config.validate().is_err());
        config = AdaptiveThresholdConfig::default();
        config.percentile = 1.5;
        assert!(config.validate().is_err());
        config = AdaptiveThresholdConfig::default();
        config.window_size = 0;
        assert!(config.validate().is_err());
        config = AdaptiveThresholdConfig::default();
        config.min_threshold = 0.5;
        config.max_threshold = 0.3;
        assert!(config.validate().is_err());
    }

    #[test]
    fn test_peak_picking_presets() {
        let onset_strength = non_empty_vec![0.1, 0.3, 0.8, 0.2, 0.4, 0.9, 0.1];
        let configs = vec![
            PeakPickingConfig::default(),
            PeakPickingConfig::music(),
            PeakPickingConfig::speech(),
            PeakPickingConfig::drums(),
        ];
        for config in configs {
            let peaks = pick_peaks(&onset_strength, &config).unwrap();
            for &peak in &peaks {
                assert!(peak < onset_strength.len().get());
            }
        }
    }

    /// Population standard deviation, used to compare signal roughness before and
    /// after smoothing.
    fn standard_deviation(values: &[f64]) -> f64 {
        let mean = values.iter().sum::<f64>() / values.len() as f64;
        let variance =
            values.iter().map(|&x| (x - mean).powi(2)).sum::<f64>() / values.len() as f64;
        variance.sqrt()
    }
}