use std::num::{NonZeroU32, NonZeroUsize};
use crate::operations::traits::{AudioDecomposition, AudioTransforms};
use crate::traits::StandardSample;
use crate::{
AudioSampleError, AudioSampleResult, AudioSamples, AudioTypeConversion, ParameterError,
};
use ndarray::{Array2, s};
use num_complex::Complex;
use spectrograms::{StftParams, StftParamsBuilder, istft as spectrograms_istft};
/// Configuration for harmonic/percussive source separation (HPSS).
///
/// Bundles the STFT parameters with the two median-filter kernel sizes and
/// the mask softness used by the HPSS routine in this module. Values are not
/// checked on construction; the `validate` method performs the range checks.
#[derive(Debug, Clone, PartialEq)]
#[non_exhaustive]
pub struct HpssConfig {
/// STFT parameters (FFT size, hop size, ...) used to analyse the signal.
pub stft_params: StftParams,
/// Kernel length of the median filter applied along the time axis of the
/// magnitude spectrogram to estimate the harmonic component.
pub median_filter_harmonic: usize,
/// Kernel length of the median filter applied along the frequency axis of
/// the magnitude spectrogram to estimate the percussive component.
pub median_filter_percussive: usize,
/// Blend between a hard binary mask (`0.0`) and a purely ratio-based mask
/// (`1.0`); validation rejects values outside `0.0..=1.0`.
pub mask_softness: f64,
}
impl HpssConfig {
    /// Creates a configuration from explicit values.
    ///
    /// No validation is performed here; call [`HpssConfig::validate`] before
    /// using the configuration.
    #[inline]
    #[must_use]
    pub const fn new(
        stft_params: StftParams,
        median_filter_harmonic: usize,
        median_filter_percussive: usize,
        mask_softness: f64,
    ) -> Self {
        Self {
            stft_params,
            median_filter_harmonic,
            median_filter_percussive,
            mask_softness,
        }
    }

    /// Builds a preset from known-good STFT sizes and filter settings.
    ///
    /// Only called with the compile-time constants of the presets below, so a
    /// builder failure here indicates a programming error.
    fn preset(
        n_fft: NonZeroUsize,
        hop_size: NonZeroUsize,
        median_filter_harmonic: usize,
        median_filter_percussive: usize,
        mask_softness: f64,
    ) -> Self {
        let stft_params = StftParamsBuilder::default()
            .n_fft(n_fft)
            .hop_size(hop_size)
            .build()
            .expect("All parameters are set and valid according to the builder");
        Self {
            stft_params,
            median_filter_harmonic,
            median_filter_percussive,
            mask_softness,
        }
    }

    /// Preset: FFT size 2048, hop 512, both median filters 31, softness 0.5.
    #[inline]
    #[must_use]
    pub fn musical() -> Self {
        Self::preset(crate::nzu!(2048), crate::nzu!(512), 31, 31, 0.5)
    }

    /// Preset: FFT size 2048, hop 256, harmonic filter 17, percussive filter 51,
    /// softness 0.1.
    #[inline]
    #[must_use]
    pub fn percussive() -> Self {
        Self::preset(crate::nzu!(2048), crate::nzu!(256), 17, 51, 0.1)
    }

    /// Preset: FFT size 4096, hop 512, harmonic filter 51, percussive filter 17,
    /// softness 0.1.
    #[inline]
    #[must_use]
    pub fn harmonic() -> Self {
        Self::preset(crate::nzu!(4096), crate::nzu!(512), 51, 17, 0.1)
    }

    /// Preset: FFT size 1024, hop 256, both median filters 11, softness 0.3.
    #[inline]
    #[must_use]
    pub fn realtime() -> Self {
        Self::preset(crate::nzu!(1024), crate::nzu!(256), 11, 11, 0.3)
    }

    /// Replaces the STFT parameters with ones built from `n_fft` and `hop_size`.
    ///
    /// # Panics
    ///
    /// Panics if `StftParamsBuilder` rejects the combination (for example, the
    /// builder refuses a hop size larger than the FFT size).
    #[inline]
    pub fn set_stft_params(&mut self, n_fft: NonZeroUsize, hop_size: NonZeroUsize) {
        self.stft_params = StftParamsBuilder::default()
            .n_fft(n_fft)
            .hop_size(hop_size)
            .build()
            .expect("StftParamsBuilder rejected the supplied n_fft/hop_size combination");
    }

    /// Sets the harmonic (time-axis) and percussive (frequency-axis) median
    /// filter kernel sizes. Values are not range-checked here.
    #[inline]
    pub const fn set_filter_sizes(&mut self, harmonic: usize, percussive: usize) {
        self.median_filter_harmonic = harmonic;
        self.median_filter_percussive = percussive;
    }

    /// Sets the mask softness, clamped to `0.0..=1.0`.
    ///
    /// A NaN input stays NaN after clamping and is later rejected by
    /// [`HpssConfig::validate`].
    #[inline]
    pub const fn set_mask_softness(&mut self, softness: f64) {
        self.mask_softness = softness.clamp(0.0, 1.0);
    }

    /// Checks that the configuration is usable at the given sample rate (Hz).
    ///
    /// # Errors
    ///
    /// Returns a parameter error when:
    /// - the FFT size is not a power of two or exceeds 16384,
    /// - the hop size exceeds the FFT size,
    /// - either median filter size is 0 or greater than 101,
    /// - `mask_softness` is NaN or outside `0.0..=1.0`,
    /// - the frequency resolution (`sample_rate / n_fft`) is coarser than 50 Hz.
    #[inline]
    pub fn validate(&self, sample_rate: f64) -> AudioSampleResult<()> {
        if !self.stft_params.n_fft().is_power_of_two() {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "win_size",
                "Window size must be a positive power of 2",
            )));
        }
        if self.stft_params.hop_size() > self.stft_params.n_fft() {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "hop_size",
                "Hop size must be positive and not larger than window size",
            )));
        }
        if self.median_filter_harmonic == 0 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "median_filter_harmonic",
                "Harmonic median filter size must be greater than 0",
            )));
        }
        if self.median_filter_percussive == 0 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "median_filter_percussive",
                "Percussive median filter size must be greater than 0",
            )));
        }
        // A range `contains` test is false for NaN, so NaN is rejected too;
        // separate `<` / `>` comparisons would let it through.
        if !(0.0..=1.0).contains(&self.mask_softness) {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "mask_softness",
                "Mask softness must be between 0.0 and 1.0",
            )));
        }
        // The limit matches the figure quoted in the error message.
        if self.stft_params.n_fft() > crate::nzu!(16_384) {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "win_size",
                "Window size should not exceed 16384 samples for practical processing",
            )));
        }
        if self.median_filter_harmonic > 101 || self.median_filter_percussive > 101 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "median_filter_size",
                "Median filter sizes should not exceed 101 for practical processing",
            )));
        }
        let freq_resolution = self.freq_resolution(sample_rate);
        if freq_resolution > 50.0 {
            return Err(AudioSampleError::Parameter(ParameterError::invalid_value(
                "win_size",
                format!(
                    "Window too small, frequency resolution ({freq_resolution:.1} Hz) is too low"
                ),
            )));
        }
        Ok(())
    }

    /// Number of bins in a one-sided spectrum: `n_fft / 2 + 1` (integer division).
    #[inline]
    #[must_use]
    pub const fn num_freq_bins(&self) -> NonZeroUsize {
        // Start from 1 and add floor(n_fft / 2): the result stays non-zero and
        // is also correct for odd FFT sizes, where rounding the half up would
        // report one bin too many.
        NonZeroUsize::MIN
            .checked_add(self.stft_params.n_fft().get() / 2)
            .expect("Div 2 plus 1 will get nowhere near the max value")
    }

    /// Spacing between adjacent frequency bins in Hz (`sample_rate / n_fft`).
    #[inline]
    #[must_use]
    pub fn freq_resolution(&self, sample_rate: f64) -> f64 {
        sample_rate / self.stft_params.n_fft().get() as f64
    }

    /// Spacing between adjacent STFT frames in seconds (`hop_size / sample_rate`).
    #[inline]
    #[must_use]
    pub fn time_resolution(&self, sample_rate: f64) -> f64 {
        self.stft_params.hop_size().get() as f64 / sample_rate
    }
}
impl Default for HpssConfig {
    /// Default settings: FFT size 2048, hop 512, both median filters 17,
    /// mask softness 0.3.
    #[inline]
    fn default() -> Self {
        let default_stft = StftParamsBuilder::default()
            .n_fft(crate::nzu!(2048))
            .hop_size(crate::nzu!(512))
            .build()
            .expect("All parameters are set and valid according to the builder");
        Self::new(default_stft, 17, 17, 0.3)
    }
}
impl<T> AudioDecomposition for AudioSamples<'_, T>
where
    T: StandardSample,
    Self: AudioTypeConversion<Sample = T>,
{
    /// Splits the signal into harmonic and percussive components.
    ///
    /// The configuration is validated against this signal's sample rate, and
    /// the signal must contain at least `n_fft` samples per channel.
    ///
    /// # Errors
    ///
    /// Returns a parameter error if validation fails or the signal is shorter
    /// than the FFT size, and forwards any error from the separation itself.
    #[inline]
    fn hpss(
        &self,
        config: &HpssConfig,
    ) -> AudioSampleResult<(
        AudioSamples<'static, Self::Sample>,
        AudioSamples<'static, Self::Sample>,
    )> {
        config.validate(self.sample_rate_hz())?;

        let required = config.stft_params.n_fft();
        let available = self.samples_per_channel();
        if available >= required {
            // Long enough for at least one analysis window: run the separation.
            return perform_hpss(self, config);
        }

        Err(AudioSampleError::Parameter(ParameterError::invalid_value(
            "signal_length",
            format!(
                "Signal too short ({} samples), need at least {} samples for window size",
                available, required
            ),
        )))
    }
}
fn perform_hpss<T>(
audio: &AudioSamples<'_, T>,
config: &HpssConfig,
) -> AudioSampleResult<(AudioSamples<'static, T>, AudioSamples<'static, T>)>
where
T: StandardSample,
{
let stft_result = audio.stft(&config.stft_params)?;
let magnitude_spec = stft_result.norm();
let harmonic_spec = median_filter_time_axis(&magnitude_spec, config.median_filter_harmonic);
let percussive_spec = median_filter_freq_axis(&magnitude_spec, config.median_filter_percussive);
let (harmonic_mask, percussive_mask) =
generate_separation_masks(&harmonic_spec, &percussive_spec, config.mask_softness);
let harmonic_stft = apply_mask_to_stft(&stft_result, &harmonic_mask);
let percussive_stft = apply_mask_to_stft(&stft_result, &percussive_mask);
let n_fft = stft_result.params.n_fft();
let hop_size = stft_result.params.hop_size();
let window = stft_result.params.window();
let centre = stft_result.params.centre();
let sample_rate = unsafe { NonZeroU32::new_unchecked(stft_result.sample_rate as u32) };
let harmonic_samples =
spectrograms_istft(&harmonic_stft, n_fft, hop_size, window.clone(), centre)?;
let harmonic_audio: AudioSamples<'static, T> =
AudioSamples::from_mono_vec::<f64>(harmonic_samples, sample_rate);
let percussive_samples = spectrograms_istft(&percussive_stft, n_fft, hop_size, window, centre)?;
let percussive_audio: AudioSamples<'static, T> =
AudioSamples::from_mono_vec::<f64>(percussive_samples, sample_rate);
Ok((harmonic_audio, percussive_audio))
}
/// Median-filters every row of `spectrogram` (axis 1, treated as time by the
/// caller) with a kernel of `kernel_size` samples.
///
/// Returns a new array of the same shape; the input is left untouched.
fn median_filter_time_axis(spectrogram: &Array2<f64>, kernel_size: usize) -> Array2<f64> {
    let mut smoothed: Array2<f64> = Array2::zeros(spectrogram.dim());
    for (bin, mut out_row) in smoothed.outer_iter_mut().enumerate() {
        let samples = spectrogram.slice(s![bin, ..]).to_vec();
        let medians = median_filter_1d(&samples, kernel_size);
        for (dst, src) in out_row.iter_mut().zip(medians) {
            *dst = src;
        }
    }
    smoothed
}
/// Median-filters every column of `spectrogram` (axis 0, treated as frequency
/// by the caller) with a kernel of `kernel_size` samples.
///
/// Returns a new array of the same shape; the input is left untouched.
fn median_filter_freq_axis(spectrogram: &Array2<f64>, kernel_size: usize) -> Array2<f64> {
    let (_, frame_count) = spectrogram.dim();
    let mut smoothed: Array2<f64> = Array2::zeros(spectrogram.dim());
    for frame in 0..frame_count {
        let samples = spectrogram.slice(s![.., frame]).to_vec();
        let medians = median_filter_1d(&samples, kernel_size);
        smoothed
            .column_mut(frame)
            .iter_mut()
            .zip(medians)
            .for_each(|(dst, src)| *dst = src);
    }
    smoothed
}
/// Maps a possibly out-of-range position onto `0..len` by mirroring at both
/// ends without repeating the edge sample (`… 2 1 | 0 1 2 … n-1 | n-2 n-3 …`).
///
/// The mirroring is periodic, so positions arbitrarily far outside the signal
/// still resolve to a valid index. `len` must be non-zero.
fn reflect_index(position: isize, len: usize) -> usize {
    debug_assert!(len > 0, "reflect_index requires a non-empty signal");
    if len == 1 {
        return 0;
    }
    // One full mirror cycle visits 0, 1, …, len-1, len-2, …, 1.
    let period = 2 * (len as isize - 1);
    let folded = position.rem_euclid(period) as usize;
    if folded < len {
        folded
    } else {
        period as usize - folded
    }
}

/// Applies a sliding median filter of `kernel_size` samples to `signal`.
///
/// The window for output `i` covers positions `i - kernel_size / 2` through
/// `i - kernel_size / 2 + kernel_size - 1`; positions outside the signal are
/// mirrored back inside. For an even `kernel_size` the median is the midpoint
/// of the two central values. Kernel sizes 0 and 1 return a copy of the input,
/// and an empty input returns an empty vector.
///
/// The kernel may be wider than the signal: mirroring is periodic, so short
/// signals are handled without index underflow.
fn median_filter_1d(signal: &[f64], kernel_size: usize) -> Vec<f64> {
    let len = signal.len();
    if kernel_size <= 1 || len == 0 {
        return signal.to_vec();
    }

    let half_kernel = kernel_size / 2;
    let middle = kernel_size / 2;
    let is_odd = kernel_size % 2 == 1;

    // One scratch buffer reused for every output sample.
    let mut window: Vec<f64> = Vec::with_capacity(kernel_size);
    let mut filtered = Vec::with_capacity(len);

    for centre in 0..len {
        let first = centre as isize - half_kernel as isize;
        window.clear();
        window.extend(
            (0..kernel_size).map(|offset| signal[reflect_index(first + offset as isize, len)]),
        );
        window.sort_unstable_by(f64::total_cmp);

        let median = if is_odd {
            window[middle]
        } else {
            f64::midpoint(window[middle - 1], window[middle])
        };
        filtered.push(median);
    }
    filtered
}
/// Builds the harmonic and percussive masks from the two filtered spectrograms.
///
/// With `mask_softness == 0.0` each cell gets a binary mask: the harmonic mask
/// is 1 where the harmonic value is greater than or equal to the percussive
/// one, otherwise the percussive mask is 1. For any other softness the mask is
/// a blend of the energy ratio (weighted by `mask_softness`) and that binary
/// decision (weighted by `1 - mask_softness`).
///
/// Both outputs take the shape of `harmonic_spec`; `percussive_spec` is indexed
/// with the same coordinates.
fn generate_separation_masks(
    harmonic_spec: &Array2<f64>,
    percussive_spec: &Array2<f64>,
    mask_softness: f64,
) -> (Array2<f64>, Array2<f64>) {
    // Keeps the ratio denominator away from zero in silent cells.
    const EPSILON: f64 = 1e-10;

    let shape = harmonic_spec.dim();
    let mut harmonic_mask: Array2<f64> = Array2::zeros(shape);
    let mut percussive_mask: Array2<f64> = Array2::zeros(shape);

    let use_hard_mask = mask_softness == 0.0;
    let binary_weight = 1.0 - mask_softness;

    for ((bin, frame), &h_val) in harmonic_spec.indexed_iter() {
        let p_val = percussive_spec[[bin, frame]];

        let (h_out, p_out) = if use_hard_mask {
            if h_val >= p_val {
                (1.0, 0.0)
            } else {
                (0.0, 1.0)
            }
        } else {
            let total = h_val + p_val + EPSILON;
            let h_wins = if h_val >= p_val { 1.0 } else { 0.0 };
            let p_wins = if p_val > h_val { 1.0 } else { 0.0 };
            (
                mask_softness.mul_add(h_val / total, binary_weight * h_wins),
                mask_softness.mul_add(p_val / total, binary_weight * p_wins),
            )
        };

        harmonic_mask[[bin, frame]] = h_out;
        percussive_mask[[bin, frame]] = p_out;
    }

    (harmonic_mask, percussive_mask)
}
/// Scales every complex STFT coefficient by the real-valued mask entry at the
/// same position and returns the result as a new array.
///
/// The output has the shape of the STFT; `mask` is indexed with the same
/// coordinates.
fn apply_mask_to_stft<A: AsRef<Array2<Complex<f64>>>>(
    stft: A,
    mask: &Array2<f64>,
) -> Array2<Complex<f64>> {
    let spectrum = stft.as_ref();
    Array2::from_shape_fn(spectrum.dim(), |position| {
        spectrum[position] * mask[position]
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::utils::generation::sine_wave;

    /// A 3-tap median filter keeps the length and produces the expected medians.
    #[test]
    fn test_median_filter_1d() {
        let signal = vec![1.0, 5.0, 3.0, 2.0, 4.0];
        let filtered = median_filter_1d(&signal, 3);
        assert_eq!(filtered.len(), signal.len());
        assert!((filtered[2] - 3.0f64).abs() < 1e-10);
        // Edges use mirrored neighbours: windows [5, 1, 5] and [2, 4, 2].
        assert_eq!(filtered, vec![5.0, 3.0, 3.0, 3.0, 2.0]);
    }

    /// HPSS on a short sine wave yields two non-empty signals at the input's
    /// sample rate and roughly the input's length.
    #[test]
    fn test_hpss_basic_functionality() {
        let sample_rate = crate::sample_rate!(8000);
        let duration = std::time::Duration::from_millis(200);
        let sine_audio = sine_wave::<f32>(440.0, duration, sample_rate, 0.5);
        let config = HpssConfig::realtime();

        let (harmonic, percussive) = sine_audio.hpss(&config).unwrap();

        let original_length = sine_audio.samples_per_channel().get();
        assert!(harmonic.samples_per_channel().get() > 0);
        assert!(percussive.samples_per_channel().get() > 0);
        let length_diff =
            (harmonic.samples_per_channel().get() as i32 - original_length as i32).abs();
        assert!(
            length_diff < 1000,
            "Length difference too large: {length_diff}"
        );
        assert_eq!(harmonic.sample_rate, sine_audio.sample_rate);
        assert_eq!(percussive.sample_rate, sine_audio.sample_rate);
    }

    /// `validate` accepts the default configuration and rejects a
    /// non-power-of-two FFT size and an out-of-range mask softness; the STFT
    /// builder itself rejects a hop size larger than the FFT size.
    #[test]
    fn test_hpss_config_validation() {
        let sample_rate = 44100.0;

        let config = HpssConfig::default();
        assert!(config.validate(sample_rate).is_ok());

        let config = HpssConfig {
            stft_params: spectrograms::StftParamsBuilder::default()
                .n_fft(crate::nzu!(1000))
                .hop_size(crate::nzu!(256))
                .build()
                .unwrap(),
            ..HpssConfig::default()
        };
        assert!(config.validate(sample_rate).is_err());

        let stft_params_result = spectrograms::StftParamsBuilder::default()
            .n_fft(crate::nzu!(1024))
            .hop_size(crate::nzu!(2048))
            .build();
        assert!(stft_params_result.is_err());

        let config = HpssConfig {
            mask_softness: 1.5,
            ..HpssConfig::default()
        };
        assert!(config.validate(sample_rate).is_err());
    }

    /// With softness 0 the masks are binary, complementary, and assign each
    /// cell to whichever spectrogram is larger there.
    #[test]
    fn test_separation_masks() {
        let harmonic_spec =
            Array2::from_shape_vec((2, 3), vec![1.0, 2.0, 3.0, 0.5, 1.0, 1.5]).unwrap();
        let percussive_spec =
            Array2::from_shape_vec((2, 3), vec![0.5, 0.5, 0.5, 2.0, 1.5, 1.0]).unwrap();

        let (h_mask, p_mask) = generate_separation_masks(&harmonic_spec, &percussive_spec, 0.0);

        for i in 0..2 {
            for j in 0..3 {
                let h_val = h_mask[[i, j]];
                let p_val = p_mask[[i, j]];
                assert!(h_val == 0.0 || h_val == 1.0);
                assert!(p_val == 0.0 || p_val == 1.0);
                // Exactly one of the two masks must claim each cell; a cell
                // claimed by neither would silently drop energy.
                assert!((h_val + p_val - 1.0f64).abs() < 1e-10);
            }
        }

        let expected_harmonic =
            Array2::from_shape_vec((2, 3), vec![1.0, 1.0, 1.0, 0.0, 0.0, 1.0]).unwrap();
        assert_eq!(h_mask, expected_harmonic);
    }
}