use super::*;
use rustfft::num_complex::Complex;
/// Spectral "inversion" effect: mirrors the interior frequency bins of each STFT
/// spectrum (every bin except the first and last) and crossfades the mirrored
/// spectrum with the unprocessed one according to `mix`.
#[derive(Clone, Debug)]
pub struct SpectralInvert {
/// STFT engine; `process` hands it a callback that rewrites each spectrum it produces.
stft: STFT,
/// FFT length given to `new` (validated there to be a power of two).
fft_size: usize,
/// Sample rate given to `new`; validated and stored but not read by this effect.
_sample_rate: f32,
/// Wet/dry balance, clamped to [0.0, 1.0] by `set_mix`: 1.0 = fully mirrored, 0.0 = dry.
mix: f32,
/// When false, `process` returns immediately and leaves its buffer untouched.
enabled: bool,
}
impl SpectralInvert {
    /// Creates an enabled spectral inverter with a fully wet mix (`1.0`).
    ///
    /// `fft_size`, `hop_size` and `window_type` are forwarded to `STFT::new`;
    /// `sample_rate` is validated and stored but not otherwise used by this effect.
    ///
    /// # Panics
    ///
    /// Panics if `fft_size` is not a power of two, if `hop_size > fft_size`, or if
    /// `sample_rate` is not strictly positive (a NaN rate is rejected as well).
    pub fn new(
        fft_size: usize,
        hop_size: usize,
        window_type: WindowType,
        sample_rate: f32,
    ) -> Self {
        assert!(fft_size.is_power_of_two(), "FFT size must be power of 2");
        assert!(hop_size <= fft_size, "Hop size must be <= FFT size");
        assert!(sample_rate > 0.0, "Sample rate must be positive");
        let stft = STFT::new(fft_size, hop_size, window_type);
        Self {
            stft,
            fft_size,
            _sample_rate: sample_rate,
            mix: 1.0,
            enabled: true,
        }
    }

    /// Sets the wet/dry balance, clamped to `[0.0, 1.0]`.
    ///
    /// `1.0` outputs only the mirrored spectrum; `0.0` outputs only the original.
    pub fn set_mix(&mut self, mix: f32) {
        self.mix = mix.clamp(0.0, 1.0);
    }

    /// Returns the current wet/dry balance.
    pub fn mix(&self) -> f32 {
        self.mix
    }

    /// Runs the spectral inversion over `output` in place.
    ///
    /// The signal is read from `output` itself; `_input` is accepted but never read.
    /// When the effect is disabled, `output` is left untouched.
    pub fn process(&mut self, output: &mut [f32], _input: &[f32]) {
        if !self.enabled {
            return;
        }
        // Copy the mix out so the closure does not borrow `self` while `self.stft`
        // is mutably borrowed.
        let mix = self.mix;
        self.stft.process(output, |spectrum| {
            Self::apply_invert_static(spectrum, mix);
        });
    }

    /// Mirrors the interior bins of `spectrum` and crossfades with the original by `mix`.
    ///
    /// Bin `i` in `1..len - 1` takes the value of bin `len - 1 - i`; bins `0` and
    /// `len - 1` keep their own values. When `mix < 1.0`, each moved bin becomes
    /// `mirrored * mix + original * (1.0 - mix)`; otherwise the pure mirror is kept.
    ///
    /// Works entirely in place with no heap allocation. This matters because it runs
    /// once per STFT spectrum inside the audio path. Spectra with fewer than three bins
    /// (including an empty one) have no interior pairs and are returned unchanged
    /// instead of panicking.
    #[inline]
    fn apply_invert_static(spectrum: &mut [Complex<f32>], mix: f32) {
        let len = spectrum.len();
        if len < 3 {
            return;
        }
        let interior = &mut spectrum[1..len - 1];

        if mix < 1.0 {
            let dry = 1.0 - mix;
            let half = interior.len() / 2;
            // `front[k]` pairs with `interior[n - 1 - k]`, the k-th element from the end of
            // `back`. For odd `n`, `back[0]` is the self-mirrored centre bin. The zip stops
            // before reaching it, so that bin is left as-is.
            let (front, back) = interior.split_at_mut(half);
            for (lo, hi) in front.iter_mut().zip(back.iter_mut().rev()) {
                let (a, b) = (*lo, *hi);
                *lo = b * mix + a * dry;
                *hi = a * mix + b * dry;
            }
        } else {
            // Reversing the interior slice is exactly the mapping i -> len - 1 - i.
            interior.reverse();
        }
    }

    /// Resets the underlying STFT via `STFT::reset`.
    pub fn reset(&mut self) {
        self.stft.reset();
    }

    /// Returns the FFT length given to `new`.
    pub fn fft_size(&self) -> usize {
        self.fft_size
    }

    /// Returns the hop size stored in the underlying STFT.
    pub fn hop_size(&self) -> usize {
        self.stft.hop_size
    }

    /// Enables or disables processing; a disabled effect leaves `process`'s buffer as-is.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Returns whether processing is enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}
impl SpectralInvert {
    /// Light preset: 30% mirrored spectrum blended with 70% of the original.
    pub fn subtle() -> Self {
        Self::with_preset_mix(0.3)
    }

    /// Full preset: only the mirrored spectrum is output.
    pub fn full() -> Self {
        Self::with_preset_mix(1.0)
    }

    /// Middle-ground preset: 60% mirrored spectrum blended with 40% of the original.
    pub fn moderate() -> Self {
        Self::with_preset_mix(0.6)
    }

    /// Builds the configuration shared by every preset and applies `mix` through
    /// `set_mix`. The shared configuration is a 2048-point FFT, a 512-sample hop,
    /// a Hann window, and a sample rate of 44100.0.
    fn with_preset_mix(mix: f32) -> Self {
        let mut preset = Self::new(2048, 512, WindowType::Hann, 44100.0);
        preset.set_mix(mix);
        preset
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    // Explicit import shadows (and agrees with) anything the glob brings in.
    use rustfft::num_complex::Complex;

    /// Spectrum whose bin `k` holds `k - k·i`, so bin positions are readable after reordering.
    fn ramp_spectrum(len: usize) -> Vec<Complex<f32>> {
        (0..len).map(|k| Complex::new(k as f32, -(k as f32))).collect()
    }

    fn real_parts(spectrum: &[Complex<f32>]) -> Vec<f32> {
        spectrum.iter().map(|c| c.re).collect()
    }

    fn assert_all_finite(samples: &[f32]) {
        assert!(samples.iter().all(|x| x.is_finite()), "output contains non-finite samples");
    }

    #[test]
    fn test_spectral_invert_creation() {
        let invert = SpectralInvert::new(2048, 512, WindowType::Hann, 44100.0);
        assert!(invert.is_enabled());
        assert_eq!(invert.fft_size(), 2048);
        assert_eq!(invert.hop_size(), 512);
        assert_eq!(invert.mix(), 1.0);
    }

    #[test]
    #[should_panic(expected = "FFT size must be power of 2")]
    fn test_spectral_invert_requires_power_of_two() {
        SpectralInvert::new(1000, 250, WindowType::Hann, 44100.0);
    }

    #[test]
    #[should_panic(expected = "Hop size must be <= FFT size")]
    fn test_spectral_invert_hop_validation() {
        SpectralInvert::new(512, 1024, WindowType::Hann, 44100.0);
    }

    #[test]
    #[should_panic(expected = "Sample rate must be positive")]
    fn test_spectral_invert_sample_rate_validation() {
        SpectralInvert::new(512, 128, WindowType::Hann, 0.0);
    }

    #[test]
    fn test_set_mix() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        invert.set_mix(0.5);
        assert_eq!(invert.mix(), 0.5);
        invert.set_mix(1.5);
        assert_eq!(invert.mix(), 1.0);
        invert.set_mix(-0.5);
        assert_eq!(invert.mix(), 0.0);
    }

    #[test]
    fn test_enable_disable() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        assert!(invert.is_enabled());
        invert.set_enabled(false);
        assert!(!invert.is_enabled());
        invert.set_enabled(true);
        assert!(invert.is_enabled());
    }

    #[test]
    fn test_process_disabled() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        invert.set_enabled(false);
        let input = vec![0.1; 512];
        let mut output = vec![0.0; 512];
        invert.process(&mut output, &input);
        assert!(output.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn test_process_disabled_leaves_signal_untouched() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        invert.set_enabled(false);
        let input = vec![0.1; 512];
        let mut output = vec![0.25; 512];
        invert.process(&mut output, &input);
        assert!(output.iter().all(|&x| x == 0.25));
    }

    #[test]
    fn test_process_basic() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        let input = vec![0.1; 512];
        let mut output = vec![0.0; 512];
        invert.process(&mut output, &input);
        assert_all_finite(&output);
    }

    #[test]
    fn test_reset() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        let input = vec![0.1; 512];
        let mut output = vec![0.0; 512];
        invert.process(&mut output, &input);
        invert.reset();
        invert.process(&mut output, &input);
        assert_all_finite(&output);
    }

    #[test]
    fn test_full_inversion() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        invert.set_mix(1.0);
        let input = vec![0.1; 512];
        let mut output = vec![0.0; 512];
        invert.process(&mut output, &input);
        assert_all_finite(&output);
    }

    #[test]
    fn test_partial_inversion() {
        let mut invert = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        invert.set_mix(0.5);
        let input = vec![0.1; 512];
        let mut output = vec![0.0; 512];
        invert.process(&mut output, &input);
        assert_all_finite(&output);
    }

    #[test]
    fn test_invert_mirrors_interior_bins() {
        let mut odd = ramp_spectrum(5);
        SpectralInvert::apply_invert_static(&mut odd, 1.0);
        assert_eq!(real_parts(&odd), vec![0.0, 3.0, 2.0, 1.0, 4.0]);
        // Imaginary parts travel with their bins.
        assert_eq!(odd[1], Complex::new(3.0, -3.0));

        let mut even = ramp_spectrum(6);
        SpectralInvert::apply_invert_static(&mut even, 1.0);
        assert_eq!(real_parts(&even), vec![0.0, 4.0, 3.0, 2.0, 1.0, 5.0]);
    }

    #[test]
    fn test_invert_short_spectra_unchanged() {
        for len in 1..=2 {
            let mut spectrum = ramp_spectrum(len);
            SpectralInvert::apply_invert_static(&mut spectrum, 1.0);
            assert_eq!(spectrum, ramp_spectrum(len));
        }
    }

    #[test]
    fn test_invert_zero_mix_is_identity() {
        let mut spectrum = ramp_spectrum(7);
        SpectralInvert::apply_invert_static(&mut spectrum, 0.0);
        assert_eq!(spectrum, ramp_spectrum(7));
    }

    #[test]
    fn test_invert_half_mix_blends_mirrored_pairs() {
        let mut spectrum = ramp_spectrum(4);
        SpectralInvert::apply_invert_static(&mut spectrum, 0.5);
        assert_eq!(real_parts(&spectrum), vec![0.0, 1.5, 1.5, 3.0]);
    }

    #[test]
    fn test_preset_subtle() {
        let invert = SpectralInvert::subtle();
        assert_eq!(invert.mix(), 0.3);
        assert!(invert.is_enabled());
    }

    #[test]
    fn test_preset_full() {
        let invert = SpectralInvert::full();
        assert_eq!(invert.mix(), 1.0);
        assert!(invert.is_enabled());
    }

    #[test]
    fn test_preset_moderate() {
        let invert = SpectralInvert::moderate();
        assert_eq!(invert.mix(), 0.6);
        assert!(invert.is_enabled());
    }

    #[test]
    fn test_different_fft_sizes() {
        let invert_512 = SpectralInvert::new(512, 128, WindowType::Hann, 44100.0);
        let invert_1024 = SpectralInvert::new(1024, 256, WindowType::Hann, 44100.0);
        let invert_2048 = SpectralInvert::new(2048, 512, WindowType::Hann, 44100.0);
        assert_eq!(invert_512.fft_size(), 512);
        assert_eq!(invert_1024.fft_size(), 1024);
        assert_eq!(invert_2048.fft_size(), 2048);
    }
}