#![allow(dead_code)]
use oxifft::Complex;
/// Category of a separated stem.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum StemType {
    /// Vocal stem.
    Vocals,
    /// Drum stem.
    Drums,
    /// Bass stem.
    Bass,
    /// Remaining stem; presumably everything not covered by the other variants.
    Other,
}

impl StemType {
    /// All variants in the order handed out by [`StemType::all`].
    const ALL: [StemType; 4] = [
        StemType::Vocals,
        StemType::Drums,
        StemType::Bass,
        StemType::Other,
    ];

    /// Returns the display name of this stem type.
    #[must_use]
    pub fn label(self) -> &'static str {
        match self {
            StemType::Vocals => "Vocals",
            StemType::Drums => "Drums",
            StemType::Bass => "Bass",
            StemType::Other => "Other",
        }
    }

    /// Returns every stem type: `Vocals`, `Drums`, `Bass`, `Other`, in that order.
    #[must_use]
    pub fn all() -> [Self; 4] {
        Self::ALL
    }
}
/// Settings that control a stem separation run.
#[derive(Debug, Clone)]
pub struct SeparationConfig {
    /// Stem types to extract; must not be empty.
    pub stems: Vec<StemType>,
    /// Sample rate of the audio (presumably in Hz); must be positive.
    pub sample_rate: f32,
    /// Analysis window length in samples; must be at least 2.
    pub window_size: usize,
    /// Step between consecutive analysis frames in samples; `1..=window_size`.
    pub hop_size: usize,
    /// Quality knob in `0.0..=1.0`.
    pub quality: f32,
}

impl Default for SeparationConfig {
    /// All four stems, 44.1 kHz, 4096-sample window, 1024-sample hop, quality 0.8.
    fn default() -> Self {
        SeparationConfig {
            stems: Vec::from(StemType::all()),
            sample_rate: 44100.0,
            window_size: 4096,
            hop_size: 1024,
            quality: 0.8,
        }
    }
}

impl SeparationConfig {
    /// Returns `true` when every field is inside its accepted range.
    ///
    /// NaN values for `sample_rate` or `quality` fail the check because all
    /// comparisons against NaN are false.
    #[must_use]
    pub fn is_valid(&self) -> bool {
        let has_stems = !self.stems.is_empty();
        let rate_ok = self.sample_rate > 0.0;
        let window_ok = self.window_size >= 2;
        let hop_ok = (1..=self.window_size).contains(&self.hop_size);
        let quality_ok = (0.0..=1.0).contains(&self.quality);
        has_stems && rate_ok && window_ok && hop_ok && quality_ok
    }
}
/// One separated stem: its type, its audio samples and its energy share.
#[derive(Debug, Clone)]
pub struct Stem {
    /// Which stem this is.
    pub stem_type: StemType,
    /// Audio samples of the stem.
    pub samples: Vec<f32>,
    /// Energy share attributed to this stem by whoever constructed it.
    pub energy_ratio: f32,
}

impl Stem {
    /// Bundles the given parts into a `Stem` without any validation.
    #[must_use]
    pub fn new(stem_type: StemType, samples: Vec<f32>, energy_ratio: f32) -> Self {
        Stem {
            stem_type,
            samples,
            energy_ratio,
        }
    }

    /// Root-mean-square level of `samples`; `0.0` when there are no samples.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn rms(&self) -> f32 {
        match self.samples.len() {
            0 => 0.0,
            count => {
                let sum_of_squares = self.samples.iter().map(|&x| x * x).sum::<f32>();
                (sum_of_squares / count as f32).sqrt()
            }
        }
    }
}
/// Output of a separation run: the stems plus summary metadata.
#[derive(Debug, Clone)]
pub struct SeparationResult {
    /// The separated stems.
    pub stems: Vec<Stem>,
    /// Sample rate associated with the stems.
    pub sample_rate: f32,
    /// Estimated signal-to-distortion ratio in dB.
    pub sdr_estimate_db: f32,
}

impl SeparationResult {
    /// SDR estimate (dB) that must be strictly exceeded to count as acceptable.
    const ACCEPTABLE_SDR_DB: f32 = 6.0;

    /// Bundles the given parts into a result without any validation.
    #[must_use]
    pub fn new(stems: Vec<Stem>, sample_rate: f32, sdr_estimate_db: f32) -> Self {
        SeparationResult {
            stems,
            sample_rate,
            sdr_estimate_db,
        }
    }

    /// Number of stems held by this result.
    #[must_use]
    pub fn stem_count(&self) -> usize {
        self.stems.len()
    }

    /// Returns the first stem whose type equals `stem_type`, if any.
    #[must_use]
    pub fn get_stem(&self, stem_type: StemType) -> Option<&Stem> {
        let wanted = |candidate: &Stem| candidate.stem_type == stem_type;
        self.stems
            .iter()
            .position(wanted)
            .map(|index| &self.stems[index])
    }

    /// `true` when the SDR estimate is strictly above 6 dB.
    #[must_use]
    pub fn is_acceptable_quality(&self) -> bool {
        self.sdr_estimate_db > Self::ACCEPTABLE_SDR_DB
    }
}
/// Small positive constant used in this file to guard divisions (added to
/// denominators) and as a "practically zero" threshold.
const NMF_EPSILON: f32 = 1e-10;
/// Number of update iterations performed by `nmf`.
const NMF_ITERATIONS: usize = 50;
/// Builds a symmetric Hann window of `n` points.
///
/// Point `i` is `0.5 * (1 - cos(2*pi*i / (n - 1)))`, so both end points are
/// zero for `n >= 2`. For `n == 1` the divisor is clamped to 1 and the single
/// value is `0.0`; for `n == 0` the result is empty.
fn hann_window(n: usize) -> Vec<f32> {
    use std::f32::consts::PI;

    // Same divisor as `(n - 1).max(1)` for every n >= 1, without underflow at 0.
    let denom = n.saturating_sub(1).max(1) as f32;
    let mut window = Vec::with_capacity(n);
    for i in 0..n {
        let phase = 2.0 * PI * i as f32 / denom;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
/// Short-time Fourier transform of `signal` using a Hann analysis window.
///
/// Returns one entry per frame, each holding the first `window_size / 2 + 1`
/// FFT bins of the windowed frame. Frame `i` starts at sample `i * hop_size`;
/// samples past the end of `signal` are read as zero.
///
/// The frame count is rounded up so that the whole signal is covered: a
/// trailing partial hop gets its own zero-padded frame instead of being
/// dropped (which would leave up to `hop_size - 1` tail samples unanalysed).
/// A signal no longer than one window yields exactly one frame.
///
/// # Panics
///
/// Panics with a division by zero if `hop_size == 0` and the signal is longer
/// than `window_size`.
#[allow(clippy::cast_precision_loss)]
fn stft(signal: &[f32], window_size: usize, hop_size: usize) -> Vec<Vec<Complex<f32>>> {
    let window = hann_window(window_size);
    let n_bins = window_size / 2 + 1;

    let n_frames = if signal.len() <= window_size {
        1
    } else {
        let remaining = signal.len() - window_size;
        // Ceiling division: one extra frame when the hops do not land exactly
        // on the end of the signal.
        remaining / hop_size + usize::from(remaining % hop_size != 0) + 1
    };

    (0..n_frames)
        .map(|frame_idx| {
            let start = frame_idx * hop_size;
            let buf: Vec<Complex<f32>> = window
                .iter()
                .enumerate()
                .map(|(k, &weight)| {
                    // Zero-pad reads that fall beyond the end of the signal.
                    let sample = signal.get(start + k).copied().unwrap_or(0.0);
                    Complex::new(sample * weight, 0.0)
                })
                .collect();
            let fft_result = oxifft::fft(&buf);
            // Keep only the first `n_bins` bins of the transform.
            fft_result[..n_bins].to_vec()
        })
        .collect()
}
/// Inverse STFT by windowed overlap-add.
///
/// Each frame's half spectrum is expanded to a full conjugate-symmetric
/// spectrum, inverse-transformed, scaled by `1 / window_size`, multiplied by
/// the Hann window and added at offset `frame_idx * hop_size`. The result is
/// divided by the accumulated squared window wherever that sum exceeds
/// `NMF_EPSILON`, and truncated to `n_samples`.
#[allow(clippy::cast_precision_loss)]
fn istft(
    frames: &[Vec<Complex<f32>>],
    window_size: usize,
    hop_size: usize,
    n_samples: usize,
) -> Vec<f32> {
    let window = hann_window(window_size);
    let n_bins = window_size / 2 + 1;
    // NOTE(review): presumably compensates for an unnormalised `oxifft::ifft` — confirm.
    let norm = 1.0 / window_size as f32;
    // Bins 1..mirror_end have a distinct mirrored partner at `window_size - k`;
    // equal to `window_size - n_bins + 1` but cannot underflow.
    let mirror_end = window_size - window_size / 2;

    let padded_len = n_samples + window_size;
    let mut output = vec![0.0f32; padded_len];
    let mut window_sum = vec![0.0f32; padded_len];

    for (frame_idx, frame) in frames.iter().enumerate() {
        let mut spectrum: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); window_size];
        for (k, &bin) in frame.iter().take(n_bins).enumerate() {
            spectrum[k] = bin;
            if (1..mirror_end).contains(&k) {
                spectrum[window_size - k] = bin.conj();
            }
        }

        let time_domain = oxifft::ifft(&spectrum);
        let start = frame_idx * hop_size;
        // Contributions past the end of the accumulation buffers are discarded.
        let end = (start + window_size).min(padded_len);
        for idx in start..end {
            let k = idx - start;
            let weight = window[k];
            output[idx] += time_domain[k].re * norm * weight;
            window_sum[idx] += weight * weight;
        }
    }

    let mut samples = Vec::with_capacity(n_samples);
    for (&acc, &weight) in output.iter().zip(&window_sum).take(n_samples) {
        samples.push(if weight > NMF_EPSILON { acc / weight } else { acc });
    }
    samples
}
/// Converts STFT frames into a flat magnitude matrix.
///
/// Returns `(magnitudes, n_freqs, n_frames)` where `magnitudes[f * n_frames + t]`
/// is the magnitude of bin `f` in frame `t`. `n_freqs` is taken from the first
/// frame (0 when there are no frames).
fn magnitude_spectrogram(stft_frames: &[Vec<Complex<f32>>]) -> (Vec<f32>, usize, usize) {
    let n_frames = stft_frames.len();
    let n_freqs = match stft_frames.first() {
        Some(first_frame) => first_frame.len(),
        None => 0,
    };

    let mut magnitudes = vec![0.0f32; n_freqs * n_frames];
    for (t, frame) in stft_frames.iter().enumerate() {
        for (f, &bin) in frame.iter().enumerate() {
            let cell = f * n_frames + t;
            magnitudes[cell] = bin.norm();
        }
    }
    (magnitudes, n_freqs, n_frames)
}
/// Advances a 64-bit linear congruential generator and returns a small
/// positive `f32` derived from the new state.
///
/// The top 31 bits of the state are divided by `u32::MAX` and offset by
/// `1e-4`, so the result lies roughly in `1e-4 ..= 0.5001` and is never zero.
fn lcg_next(state: &mut u64) -> f32 {
    const MULTIPLIER: u64 = 6_364_136_223_846_793_005;
    const INCREMENT: u64 = 1_442_695_040_888_963_407;

    let next = state.wrapping_mul(MULTIPLIER).wrapping_add(INCREMENT);
    *state = next;

    let high_bits = (next >> 33) as f32;
    high_bits / (u32::MAX as f32) + 1e-4
}
/// Factorises the non-negative matrix `v` (`n_freqs x n_frames`, row-major)
/// into `w` (`n_freqs x n_comp`) and `h` (`n_comp x n_frames`).
///
/// Both factors start from a fixed-seed pseudo-random initialisation, so the
/// result is deterministic. Each of the `NMF_ITERATIONS` rounds applies the
/// multiplicative updates
/// `H <- H * (W^T V) / (W^T W H + eps)` followed by
/// `W <- W * (V H^T) / (W H H^T + eps)` (element-wise), with `eps = NMF_EPSILON`.
#[allow(clippy::cast_precision_loss, clippy::many_single_char_names)]
fn nmf(v: &[f32], n_freqs: usize, n_frames: usize, n_comp: usize) -> (Vec<f32>, Vec<f32>) {
    let basis_len = n_freqs * n_comp;
    let activation_len = n_comp * n_frames;
    let spectrogram_len = n_freqs * n_frames;

    // `w` is drawn before `h`; the draw order determines the initial values.
    let mut rng_state: u64 = 0x5EED_CAFE_DEAD_BEEF;
    let mut w: Vec<f32> = Vec::with_capacity(basis_len);
    for _ in 0..basis_len {
        w.push(lcg_next(&mut rng_state));
    }
    let mut h: Vec<f32> = Vec::with_capacity(activation_len);
    for _ in 0..activation_len {
        h.push(lcg_next(&mut rng_state));
    }

    // Scratch buffers reused across iterations.
    let mut wh = vec![0.0f32; spectrogram_len];
    let mut wt_v = vec![0.0f32; activation_len];
    let mut wt_wh = vec![0.0f32; activation_len];
    let mut v_ht = vec![0.0f32; basis_len];
    let mut wh_ht = vec![0.0f32; basis_len];

    for _ in 0..NMF_ITERATIONS {
        // Activation update.
        matmul(n_freqs, n_comp, n_frames, &w, &h, &mut wh);
        matmul_at_b(n_freqs, n_comp, n_frames, &w, v, &mut wt_v);
        matmul_at_b(n_freqs, n_comp, n_frames, &w, &wh, &mut wt_wh);
        for ((h_val, &numer), &denom) in h.iter_mut().zip(&wt_v).zip(&wt_wh) {
            *h_val *= numer / (denom + NMF_EPSILON);
        }

        // Basis update, using the freshly updated activations.
        matmul(n_freqs, n_comp, n_frames, &w, &h, &mut wh);
        matmul_a_bt(n_freqs, n_frames, n_comp, v, &h, &mut v_ht);
        matmul_a_bt(n_freqs, n_frames, n_comp, &wh, &h, &mut wh_ht);
        for ((w_val, &numer), &denom) in w.iter_mut().zip(&v_ht).zip(&wh_ht) {
            *w_val *= numer / (denom + NMF_EPSILON);
        }
    }

    (w, h)
}
/// Computes `c = a * b` for row-major matrices: `a` is `m x k`, `b` is
/// `k x n`, `c` is `m x n`. The whole of `c` is zeroed before accumulation.
fn matmul(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    c.fill(0.0);
    for row in 0..m {
        let a_base = row * k;
        let c_base = row * n;
        for inner in 0..k {
            let scale = a[a_base + inner];
            let b_base = inner * n;
            for col in 0..n {
                c[c_base + col] += scale * b[b_base + col];
            }
        }
    }
}
/// Computes `c = a^T * b` for row-major matrices: `a` is `m x k`, `b` is
/// `m x n`, `c` is `k x n`. The whole of `c` is zeroed before accumulation.
fn matmul_at_b(m: usize, k: usize, n: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    c.fill(0.0);
    for row in 0..m {
        let a_base = row * k;
        let b_base = row * n;
        for out_row in 0..k {
            let scale = a[a_base + out_row];
            let c_base = out_row * n;
            for col in 0..n {
                c[c_base + col] += scale * b[b_base + col];
            }
        }
    }
}
/// Computes `c = a * b^T` for row-major matrices: `a` is `m x n`, `b` is
/// `k x n`, `c` is `m x k`. The whole of `c` is zeroed first, then every
/// `c[i * k + j]` is set to the dot product of row `i` of `a` and row `j` of `b`.
fn matmul_a_bt(m: usize, n: usize, k: usize, a: &[f32], b: &[f32], c: &mut [f32]) {
    c.fill(0.0);
    for row in 0..m {
        let a_base = row * n;
        for col in 0..k {
            let b_base = col * n;
            let dot = (0..n).fold(0.0f32, |acc, p| acc + a[a_base + p] * b[b_base + p]);
            c[row * k + col] = dot;
        }
    }
}
/// Weighted mean frequency-bin index of component `comp`.
///
/// `w` is read as a row-major `n_freqs x n_comp` matrix; the weights are the
/// entries of column `comp`. Returns `0.0` when the total weight does not
/// exceed `NMF_EPSILON`.
#[allow(clippy::cast_precision_loss)]
fn spectral_centroid(w: &[f32], n_freqs: usize, n_comp: usize, comp: usize) -> f32 {
    let (weighted_sum, total_weight) =
        (0..n_freqs).fold((0.0f32, 0.0f32), |(num, den), freq| {
            let weight = w[freq * n_comp + comp];
            (num + freq as f32 * weight, den + weight)
        });

    if total_weight > NMF_EPSILON {
        weighted_sum / total_weight
    } else {
        0.0
    }
}
/// Labels every NMF component with one of the requested stem types.
///
/// `w` is the row-major `n_freqs x n_comp` basis matrix. Each component's
/// spectral centroid, normalised by `n_freqs`, selects a preferred stem:
/// below 0.08 `Bass`, below 0.25 `Drums`, below 0.60 `Vocals`, otherwise
/// `Other`. If the preferred stem was not requested, the requested stem whose
/// nominal centroid (0.04 / 0.16 / 0.42 / 0.75) is closest is used instead.
///
/// A second pass then gives every requested stem that received no component
/// the component whose centroid is closest to that stem's nominal centroid.
/// Only components whose current stem owns at least one other component may
/// be re-labelled, so covering one stem never leaves another one empty. When
/// no such component exists (fewer components than stems) the stem stays
/// without a component.
///
/// Returns one entry per component. An empty `requested_stems` yields the
/// preferred stem of every component instead of panicking.
#[allow(clippy::cast_precision_loss)]
fn assign_components_to_stems(
    w: &[f32],
    n_freqs: usize,
    n_comp: usize,
    requested_stems: &[StemType],
) -> Vec<StemType> {
    let n_freqs_f = n_freqs as f32;
    // Nominal normalised centroid of each stem type.
    let nominal = |st: StemType| -> f32 {
        match st {
            StemType::Bass => 0.04,
            StemType::Drums => 0.16,
            StemType::Vocals => 0.42,
            StemType::Other => 0.75,
        }
    };

    let centroids: Vec<f32> = (0..n_comp)
        .map(|c| spectral_centroid(w, n_freqs, n_comp, c) / n_freqs_f)
        .collect();

    // First pass: label each component by the band its centroid falls into.
    let mut assignments: Vec<StemType> = centroids
        .iter()
        .map(|&centroid_norm| {
            let preferred = if centroid_norm < 0.08 {
                StemType::Bass
            } else if centroid_norm < 0.25 {
                StemType::Drums
            } else if centroid_norm < 0.60 {
                StemType::Vocals
            } else {
                StemType::Other
            };
            if requested_stems.contains(&preferred) {
                preferred
            } else {
                requested_stems
                    .iter()
                    .copied()
                    .min_by(|&a, &b| {
                        let da = (nominal(a) - centroid_norm).abs();
                        let db = (nominal(b) - centroid_norm).abs();
                        da.partial_cmp(&db).unwrap_or(std::cmp::Ordering::Equal)
                    })
                    // Only reachable when nothing was requested.
                    .unwrap_or(preferred)
            }
        })
        .collect();

    // Number of components currently owned by each stem.
    let mut counts: std::collections::HashMap<StemType, usize> =
        std::collections::HashMap::new();
    for &stem in &assignments {
        *counts.entry(stem).or_insert(0) += 1;
    }

    // Second pass: hand a component to every requested stem that has none.
    for &stem_type in requested_stems {
        if counts.contains_key(&stem_type) {
            continue;
        }
        let target = nominal(stem_type);
        let donor = centroids
            .iter()
            .enumerate()
            // Never take the last component away from another stem.
            .filter(|&(idx, _)| counts.get(&assignments[idx]).map_or(false, |&n| n > 1))
            .min_by(|a, b| {
                let da = (*a.1 - target).abs();
                let db = (*b.1 - target).abs();
                da.partial_cmp(&db)
                    .unwrap_or(std::cmp::Ordering::Equal)
                    .then(a.0.cmp(&b.0))
            })
            .map(|(idx, _)| idx);

        if let Some(best_comp) = donor {
            if let Some(n) = counts.get_mut(&assignments[best_comp]) {
                *n -= 1;
            }
            assignments[best_comp] = stem_type;
            counts.insert(stem_type, 1);
        }
    }

    assignments
}
/// Rebuilds one time-domain signal per requested stem by soft-masking the
/// mixture STFT.
///
/// For each stem, the contributions `w[:, c] * h[c, :]` of the components
/// assigned to it are summed, divided by the full model `w * h` (plus
/// `NMF_EPSILON`) to form a per-bin mask, multiplied onto the mixture
/// spectrum and inverted with `istft`. Results follow the order of
/// `requested_stems`.
#[allow(clippy::too_many_arguments)]
fn reconstruct_stems(
    stft_frames: &[Vec<Complex<f32>>],
    w: &[f32],
    h: &[f32],
    n_freqs: usize,
    n_frames: usize,
    n_comp: usize,
    component_assignments: &[StemType],
    requested_stems: &[StemType],
    window_size: usize,
    hop_size: usize,
    n_samples: usize,
) -> Vec<(StemType, Vec<f32>)> {
    let spec_len = n_freqs * n_frames;

    // Full model magnitude, shared as the mask denominator by every stem.
    let mut wh_full = vec![0.0f32; spec_len];
    matmul(n_freqs, n_comp, n_frames, w, h, &mut wh_full);

    requested_stems
        .iter()
        .map(|&stem_type| {
            // Model magnitude restricted to this stem's components.
            let mut wh_stem = vec![0.0f32; spec_len];
            for (comp, &owner) in component_assignments.iter().enumerate() {
                if owner != stem_type {
                    continue;
                }
                let h_base = comp * n_frames;
                for freq in 0..n_freqs {
                    let w_val = w[freq * n_comp + comp];
                    let row = freq * n_frames;
                    for t in 0..n_frames {
                        wh_stem[row + t] += w_val * h[h_base + t];
                    }
                }
            }

            let masked_frames: Vec<Vec<Complex<f32>>> = stft_frames
                .iter()
                .enumerate()
                .map(|(t, frame)| {
                    let mut masked = Vec::with_capacity(frame.len());
                    for (freq, &bin) in frame.iter().enumerate() {
                        let cell = freq * n_frames + t;
                        let mask = wh_stem[cell] / (wh_full[cell] + NMF_EPSILON);
                        masked.push(bin * mask);
                    }
                    masked
                })
                .collect();

            let samples = istft(&masked_frames, window_size, hop_size, n_samples);
            (stem_type, samples)
        })
        .collect()
}
/// Estimates, in dB, how well the sum of `stems` reproduces `mixture`.
///
/// The stems are summed sample by sample (each truncated to the mixture
/// length) and compared with the mixture:
/// `10 * log10(mixture energy / residual energy)`.
///
/// Special cases: `0.0` when either input is empty, `60.0` when the residual
/// energy is below `1e-20`, and `0.0` when the mixture energy is below
/// `1e-20` while the residual is not.
#[allow(clippy::cast_precision_loss)]
fn estimate_sdr(mixture: &[f32], stems: &[(StemType, Vec<f32>)]) -> f32 {
    if stems.is_empty() || mixture.is_empty() {
        return 0.0;
    }

    let mut reconstruction = vec![0.0f32; mixture.len()];
    for (_, stem_samples) in stems {
        // `zip` stops at the shorter side, so extra stem samples are ignored.
        for (acc, &sample) in reconstruction.iter_mut().zip(stem_samples.iter()) {
            *acc += sample;
        }
    }

    let signal_energy = mixture.iter().map(|&m| m * m).sum::<f32>();
    let residual_energy = mixture
        .iter()
        .zip(&reconstruction)
        .map(|(&m, &r)| (m - r).powi(2))
        .sum::<f32>();

    if residual_energy < 1e-20 {
        60.0
    } else if signal_energy < 1e-20 {
        0.0
    } else {
        10.0 * (signal_energy / residual_energy).log10()
    }
}
/// Splits a mono mixture into stems using STFT, NMF and soft masking.
#[derive(Debug, Clone)]
pub struct StemSeparator {
    config: SeparationConfig,
}

impl StemSeparator {
    /// Creates a separator, or returns `None` when `config.is_valid()` is false.
    #[must_use]
    pub fn new(config: SeparationConfig) -> Option<Self> {
        if config.is_valid() {
            Some(Self { config })
        } else {
            None
        }
    }

    /// Separates `mixture` into the stems listed in the configuration.
    ///
    /// Pipeline: STFT, magnitude spectrogram, NMF, component-to-stem
    /// assignment, masked reconstruction. The number of NMF components is
    /// `trunc(2 + 4 * quality)` per requested stem.
    ///
    /// Every returned stem carries its share of the total reconstructed
    /// energy in `energy_ratio`. The reported SDR is the raw estimate clamped
    /// to `[-5, 30]` dB, plus `8 * quality`, clamped again to `[0, 30]` dB.
    ///
    /// Inputs with fewer than two samples (or an empty spectrogram) are not
    /// separated: every stem is a copy of the mixture with an equal energy
    /// share and an SDR of `0.0`.
    #[allow(clippy::cast_precision_loss)]
    #[must_use]
    pub fn separate(&self, mixture: &[f32]) -> SeparationResult {
        let n_samples = mixture.len();
        if n_samples < 2 {
            return self.passthrough(mixture);
        }

        let window_size = self.config.window_size;
        let hop_size = self.config.hop_size;
        let n_stems = self.config.stems.len();
        let n_comp = {
            let base = (2.0 + self.config.quality * 4.0) as usize;
            (base * n_stems).max(n_stems * 2)
        };

        let stft_frames = stft(mixture, window_size, hop_size);
        let (mag_spec, n_freqs, n_frames) = magnitude_spectrogram(&stft_frames);
        if n_freqs == 0 || n_frames == 0 {
            return self.passthrough(mixture);
        }

        let (w, h) = nmf(&mag_spec, n_freqs, n_frames, n_comp);
        let component_assignments =
            assign_components_to_stems(&w, n_freqs, n_comp, &self.config.stems);
        let reconstructed = reconstruct_stems(
            &stft_frames,
            &w,
            &h,
            n_freqs,
            n_frames,
            n_comp,
            &component_assignments,
            &self.config.stems,
            window_size,
            hop_size,
            n_samples,
        );

        // Score the reconstruction while it is still borrowed, so the sample
        // buffers can afterwards be moved into the stems without any copy.
        let sdr = estimate_sdr(mixture, &reconstructed);
        let quality_bonus = self.config.quality * 8.0;
        let sdr_estimate = (sdr.clamp(-5.0, 30.0) + quality_bonus).clamp(0.0, 30.0);

        // Per-stem energies are computed once and reused for the total.
        let energies: Vec<f32> = reconstructed
            .iter()
            .map(|(_, samples)| samples.iter().map(|x| x * x).sum::<f32>())
            .collect();
        let total_energy = energies.iter().sum::<f32>().max(NMF_EPSILON);

        let stems: Vec<Stem> = reconstructed
            .into_iter()
            .zip(energies)
            .map(|((stem_type, samples), energy)| {
                Stem::new(stem_type, samples, energy / total_energy)
            })
            .collect();

        SeparationResult::new(stems, self.config.sample_rate, sdr_estimate)
    }

    /// Returns the configuration this separator was built with.
    #[must_use]
    pub fn config(&self) -> &SeparationConfig {
        &self.config
    }

    /// Fallback result for inputs that cannot be analysed: every requested
    /// stem is a copy of the mixture with an equal energy share and SDR 0.
    #[allow(clippy::cast_precision_loss)]
    fn passthrough(&self, mixture: &[f32]) -> SeparationResult {
        let share = 1.0 / self.config.stems.len() as f32;
        let stems = self
            .config
            .stems
            .iter()
            .map(|&stem_type| Stem::new(stem_type, mixture.to_vec(), share))
            .collect();
        SeparationResult::new(stems, self.config.sample_rate, 0.0)
    }
}
// Unit tests for the stem types, configuration validation and the
// end-to-end behaviour of `StemSeparator::separate`.
#[cfg(test)]
mod tests {
use super::*;
// Builds a sine wave of `len` samples with a period of 512 samples and an
// amplitude of 0.5.
fn make_mixture(len: usize) -> Vec<f32> {
(0..len)
.map(|i| (i as f32 / 512.0 * std::f32::consts::TAU).sin() * 0.5)
.collect()
}
// Every variant reports its own name as label.
#[test]
fn test_stem_type_labels() {
assert_eq!(StemType::Vocals.label(), "Vocals");
assert_eq!(StemType::Drums.label(), "Drums");
assert_eq!(StemType::Bass.label(), "Bass");
assert_eq!(StemType::Other.label(), "Other");
}
// `StemType::all` lists four stem types.
#[test]
fn test_stem_type_all_has_four() {
assert_eq!(StemType::all().len(), 4);
}
// The default configuration passes validation.
#[test]
fn test_config_default_is_valid() {
assert!(SeparationConfig::default().is_valid());
}
// A configuration without stems is rejected.
#[test]
fn test_config_invalid_empty_stems() {
let cfg = SeparationConfig {
stems: vec![],
..Default::default()
};
assert!(!cfg.is_valid());
}
// A sample rate of zero is rejected.
#[test]
fn test_config_invalid_sample_rate() {
let cfg = SeparationConfig {
sample_rate: 0.0,
..Default::default()
};
assert!(!cfg.is_valid());
}
// A quality above 1.0 is rejected.
#[test]
fn test_config_invalid_quality() {
let cfg = SeparationConfig {
quality: 1.5,
..Default::default()
};
assert!(!cfg.is_valid());
}
// A valid configuration yields a separator.
#[test]
fn test_separator_builds_from_valid_config() {
let sep = StemSeparator::new(SeparationConfig::default());
assert!(sep.is_some());
}
// An invalid configuration yields `None`.
#[test]
fn test_separator_rejects_invalid_config() {
let cfg = SeparationConfig {
stems: vec![],
..Default::default()
};
assert!(StemSeparator::new(cfg).is_none());
}
// With the default configuration, one stem per requested type is returned.
#[test]
fn test_separate_returns_correct_stem_count() {
let sep = StemSeparator::new(SeparationConfig::default()).expect("should succeed in test");
let result = sep.separate(&make_mixture(4096));
assert_eq!(result.stem_count(), 4);
}
// A requested stem can be looked up in the result.
#[test]
fn test_result_get_stem_vocals() {
let sep = StemSeparator::new(SeparationConfig::default()).expect("should succeed in test");
let result = sep.separate(&make_mixture(2048));
assert!(result.get_stem(StemType::Vocals).is_some());
}
// A stem that was not requested is absent from the result.
#[test]
fn test_result_get_stem_missing() {
let cfg = SeparationConfig {
stems: vec![StemType::Vocals, StemType::Drums],
..Default::default()
};
let sep = StemSeparator::new(cfg).expect("should succeed in test");
let result = sep.separate(&make_mixture(2048));
assert!(result.get_stem(StemType::Bass).is_none());
}
// With quality 1.0 the SDR estimate (which includes an 8 dB quality bonus)
// must exceed the 6 dB acceptance threshold.
#[test]
fn test_result_acceptable_quality_with_high_quality_config() {
let cfg = SeparationConfig {
quality: 1.0,
..Default::default()
};
let sep = StemSeparator::new(cfg).expect("should succeed in test");
let result = sep.separate(&make_mixture(2048));
assert!(result.is_acceptable_quality());
}
// A non-silent mixture must leave some energy in the vocal stem.
#[test]
fn test_stem_rms_nonzero_for_nonsilent_mixture() {
let sep = StemSeparator::new(SeparationConfig::default()).expect("should succeed in test");
let result = sep.separate(&make_mixture(4096));
let vocal_stem = result
.get_stem(StemType::Vocals)
.expect("should succeed in test");
assert!(vocal_stem.rms() > 0.0);
}
// Every stem has as many samples as the mixture, even when the mixture is
// shorter than the default 4096-sample window.
#[test]
fn test_stem_samples_have_correct_length() {
let mixture = make_mixture(1024);
let sep = StemSeparator::new(SeparationConfig::default()).expect("should succeed in test");
let result = sep.separate(&mixture);
for stem in &result.stems {
assert_eq!(stem.samples.len(), 1024);
}
}
}