use super::*;
use rustfft::num_complex::Complex;
/// Per-bin spectral gate driven by an [`STFT`].
///
/// For every frequency bin handed to the STFT callback the gate converts the
/// bin magnitude to dB, compares it with a threshold, and smooths the
/// resulting gain (1.0 when at/above the threshold, `ratio` when below) with a
/// one-pole attack/release envelope before scaling the bin.
#[derive(Clone)]
pub struct SpectralGate {
    /// STFT engine whose `process` call invokes the per-spectrum gate callback.
    stft: STFT,
    /// FFT size passed to the constructor (reported by [`SpectralGate::fft_size`]).
    fft_size: usize,
    /// Sample rate used to convert the hop size into a time step.
    sample_rate: f32,
    /// Gate threshold in dB, compared against `20 * log10(|bin|)`.
    threshold_db: f32,
    /// Attack time constant, in the time unit implied by `sample_rate`
    /// (seconds when `sample_rate` is in Hz).
    attack: f32,
    /// Release time constant, same unit as `attack`.
    release: f32,
    /// Target gain for bins below the threshold, kept within `[0.0, 1.0]` by
    /// [`SpectralGate::set_ratio`].
    ratio: f32,
    /// Smoothing coefficient used while the envelope rises towards its target.
    attack_coeff: f32,
    /// Smoothing coefficient used while the envelope falls towards its target.
    release_coeff: f32,
    /// Smoothed gain per frequency bin, carried over between frames.
    envelope: Vec<f32>,
    /// Reusable scratch buffer for bin magnitudes, so that processing a frame
    /// does not allocate.
    magnitudes: Vec<f32>,
    /// When `false`, [`SpectralGate::process`] leaves the buffer untouched.
    enabled: bool,
}

impl SpectralGate {
    /// Magnitudes at or below this value are treated as silence.
    const MIN_MAGNITUDE: f32 = 1e-10;
    /// dB value reported for silent bins; equals `20 * log10(MIN_MAGNITUDE)`
    /// so the magnitude-to-dB mapping stays monotonic at the floor.
    const FLOOR_DB: f32 = -200.0;

    /// Creates a gate with a -40 dB threshold, 1 ms attack, 50 ms release,
    /// a below-threshold gain (`ratio`) of 0.0, and processing enabled.
    ///
    /// # Panics
    ///
    /// Panics if `fft_size` is not a power of two, if `hop_size` is zero or
    /// larger than `fft_size`, or if `sample_rate` is not strictly positive.
    pub fn new(
        fft_size: usize,
        hop_size: usize,
        window_type: WindowType,
        sample_rate: f32,
    ) -> Self {
        assert!(fft_size.is_power_of_two(), "FFT size must be power of 2");
        // A zero hop would yield smoothing coefficients of exactly 1.0, which
        // freezes the envelope forever.
        assert!(hop_size > 0, "Hop size must be > 0");
        assert!(hop_size <= fft_size, "Hop size must be <= FFT size");
        assert!(sample_rate > 0.0, "Sample rate must be positive");

        let mut gate = Self {
            stft: STFT::new(fft_size, hop_size, window_type),
            fft_size,
            sample_rate,
            threshold_db: -40.0,
            attack: 0.001,
            release: 0.050,
            ratio: 0.0,
            // Placeholder values; the real coefficients are derived below.
            attack_coeff: 0.0,
            release_coeff: 0.0,
            envelope: vec![0.0; fft_size],
            magnitudes: Vec::with_capacity(fft_size),
            enabled: true,
        };
        gate.update_coefficients();
        gate
    }

    /// Sets the gate threshold in dB. The value is stored as given.
    pub fn set_threshold(&mut self, threshold_db: f32) {
        self.threshold_db = threshold_db;
    }

    /// Sets the attack time constant (lower bound `0.0001`) and refreshes the
    /// smoothing coefficients.
    pub fn set_attack(&mut self, attack: f32) {
        self.attack = attack.max(0.0001);
        self.update_coefficients();
    }

    /// Sets the release time constant (lower bound `0.001`) and refreshes the
    /// smoothing coefficients.
    pub fn set_release(&mut self, release: f32) {
        self.release = release.max(0.001);
        self.update_coefficients();
    }

    /// Sets the gain applied to bins below the threshold, clamped to
    /// `[0.0, 1.0]` (0.0 mutes them, 1.0 leaves them unchanged).
    pub fn set_ratio(&mut self, ratio: f32) {
        self.ratio = ratio.clamp(0.0, 1.0);
    }

    /// Returns the current threshold in dB.
    pub fn threshold(&self) -> f32 {
        self.threshold_db
    }

    /// Returns the current attack time constant.
    pub fn attack(&self) -> f32 {
        self.attack
    }

    /// Returns the current release time constant.
    pub fn release(&self) -> f32 {
        self.release
    }

    /// Returns the gain applied to bins below the threshold.
    pub fn ratio(&self) -> f32 {
        self.ratio
    }

    /// Recomputes the one-pole coefficients as `exp(-hop_time / time_constant)`,
    /// where `hop_time` is the duration of one STFT hop.
    fn update_coefficients(&mut self) {
        let hop_time = self.stft.hop_size() as f32 / self.sample_rate;
        self.attack_coeff = (-hop_time / self.attack).exp();
        self.release_coeff = (-hop_time / self.release).exp();
    }

    /// Runs the gate over `output` through the STFT.
    ///
    /// Does nothing when the gate is disabled.
    ///
    /// NOTE(review): `_input` is not read. `output` is the only buffer handed
    /// to `STFT::process`, so it presumably serves as both source and
    /// destination — confirm against `STFT::process`.
    pub fn process(&mut self, output: &mut [f32], _input: &[f32]) {
        if !self.enabled {
            return;
        }

        // Copy the scalar parameters and borrow the state buffers up front so
        // the closure does not need to borrow `self` while `self.stft` is
        // mutably borrowed.
        let threshold_db = self.threshold_db;
        let ratio = self.ratio;
        let attack_coeff = self.attack_coeff;
        let release_coeff = self.release_coeff;
        let envelope = &mut self.envelope;
        let magnitudes = &mut self.magnitudes;

        self.stft.process(output, |spectrum| {
            Self::apply_gate_static(
                spectrum,
                envelope,
                magnitudes,
                threshold_db,
                ratio,
                attack_coeff,
                release_coeff,
            );
        });
    }

    /// Applies the gate to one spectrum in place and updates `envelope`.
    ///
    /// `magnitudes` is scratch space; its previous contents are discarded.
    #[inline]
    fn apply_gate_static(
        spectrum: &mut [Complex<f32>],
        envelope: &mut Vec<f32>,
        magnitudes: &mut Vec<f32>,
        threshold_db: f32,
        ratio: f32,
        attack_coeff: f32,
        release_coeff: f32,
    ) {
        let len = spectrum.len();

        // Zero-filled scratch of exactly `len` entries. This reuses the
        // existing allocation whenever its capacity is sufficient.
        magnitudes.clear();
        magnitudes.resize(len, 0.0);
        ComplexOps::magnitude(&mut *magnitudes, spectrum);

        // Guard against a spectrum longer than the envelope instead of
        // panicking on an out-of-range index; new bins start at gain 0.0.
        if envelope.len() < len {
            envelope.resize(len, 0.0);
        }

        for ((bin, env), &mag) in spectrum
            .iter_mut()
            .zip(envelope.iter_mut())
            .zip(magnitudes.iter())
        {
            let mag_db = if mag > Self::MIN_MAGNITUDE {
                20.0 * mag.log10()
            } else {
                Self::FLOOR_DB
            };

            // Fully open at/above the threshold, otherwise attenuate to `ratio`.
            let target_gain = if mag_db >= threshold_db { 1.0 } else { ratio };

            // Rising gain uses the attack coefficient, falling (or equal) gain
            // uses the release coefficient.
            let coeff = if target_gain > *env {
                attack_coeff
            } else {
                release_coeff
            };

            *env = target_gain + coeff * (*env - target_gain);
            bin.re *= *env;
            bin.im *= *env;
        }
    }

    /// Resets the STFT state and sets every bin's envelope back to 0.0.
    pub fn reset(&mut self) {
        self.stft.reset();
        self.envelope.fill(0.0);
    }

    /// Returns the FFT size given at construction.
    pub fn fft_size(&self) -> usize {
        self.fft_size
    }

    /// Returns the hop size reported by the underlying STFT.
    pub fn hop_size(&self) -> usize {
        self.stft.hop_size()
    }

    /// Enables or disables processing.
    pub fn set_enabled(&mut self, enabled: bool) {
        self.enabled = enabled;
    }

    /// Returns `true` when processing is enabled.
    pub fn is_enabled(&self) -> bool {
        self.enabled
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate shared by every test that needs a valid one.
    const SAMPLE_RATE: f32 = 44100.0;

    /// Builds a Hann-windowed gate at the shared test sample rate.
    fn hann_gate(fft_size: usize, hop_size: usize) -> SpectralGate {
        SpectralGate::new(fft_size, hop_size, WindowType::Hann, SAMPLE_RATE)
    }

    /// Runs a single `process` call with input and output buffers of `len`
    /// samples, both pre-filled with `value`, and returns the output buffer.
    fn process_constant(gate: &mut SpectralGate, value: f32, len: usize) -> Vec<f32> {
        let input = vec![value; len];
        let mut output = vec![value; len];
        gate.process(&mut output, &input);
        output
    }

    // Construction and argument validation.

    #[test]
    fn test_spectral_gate_creation() {
        let gate = hann_gate(2048, 512);
        assert_eq!(gate.fft_size(), 2048);
        assert_eq!(gate.hop_size(), 512);
        assert_eq!(gate.threshold(), -40.0);
        assert!(gate.is_enabled());
    }

    #[test]
    #[should_panic(expected = "FFT size must be power of 2")]
    fn test_spectral_gate_requires_power_of_two() {
        let _ = hann_gate(1000, 250);
    }

    #[test]
    #[should_panic(expected = "Hop size must be <= FFT size")]
    fn test_spectral_gate_hop_validation() {
        let _ = hann_gate(1024, 2048);
    }

    #[test]
    #[should_panic(expected = "Sample rate must be positive")]
    fn test_spectral_gate_sample_rate_validation() {
        let _ = SpectralGate::new(1024, 256, WindowType::Hann, 0.0);
    }

    // Parameter setters and their clamping.

    #[test]
    fn test_spectral_gate_set_threshold() {
        let mut gate = hann_gate(1024, 256);
        for threshold in [-30.0, -60.0] {
            gate.set_threshold(threshold);
            assert_eq!(gate.threshold(), threshold);
        }
    }

    #[test]
    fn test_spectral_gate_set_attack() {
        let mut gate = hann_gate(1024, 256);
        // (requested, stored): the second request falls below the lower bound.
        for (requested, stored) in [(0.01, 0.01), (0.00001, 0.0001)] {
            gate.set_attack(requested);
            assert_eq!(gate.attack(), stored);
        }
    }

    #[test]
    fn test_spectral_gate_set_release() {
        let mut gate = hann_gate(1024, 256);
        // (requested, stored): the second request falls below the lower bound.
        for (requested, stored) in [(0.1, 0.1), (0.0001, 0.001)] {
            gate.set_release(requested);
            assert_eq!(gate.release(), stored);
        }
    }

    #[test]
    fn test_spectral_gate_set_ratio() {
        let mut gate = hann_gate(1024, 256);
        // (requested, stored): out-of-range requests are clamped to [0, 1].
        for (requested, stored) in [(0.5, 0.5), (1.5, 1.0), (-0.5, 0.0)] {
            gate.set_ratio(requested);
            assert_eq!(gate.ratio(), stored);
        }
    }

    // Processing.

    #[test]
    fn test_spectral_gate_process_silent() {
        let mut gate = hann_gate(1024, 256);
        let output = process_constant(&mut gate, 0.0, 512);
        for &sample in &output {
            assert!(sample.abs() < 0.001, "Expected silence, got {}", sample);
        }
    }

    #[test]
    fn test_spectral_gate_process_with_threshold() {
        let mut gate = hann_gate(512, 128);
        gate.set_threshold(-20.0);
        gate.set_ratio(0.0);
        let output = process_constant(&mut gate, 0.0, 256);
        assert_eq!(output.len(), 256);
    }

    #[test]
    fn test_spectral_gate_disabled() {
        let mut gate = hann_gate(512, 128);
        gate.set_enabled(false);
        let output = process_constant(&mut gate, 1.0, 256);
        assert_eq!(output[0], 1.0);
    }

    #[test]
    fn test_spectral_gate_reset() {
        let mut gate = hann_gate(512, 128);
        let input = vec![0.0; 256];
        let mut output = vec![0.0; 256];
        // The same output buffer is deliberately reused across the reset.
        gate.process(&mut output, &input);
        gate.reset();
        gate.process(&mut output, &input);
        assert_eq!(output.len(), 256);
    }

    #[test]
    fn test_spectral_gate_all_window_types() {
        let window_types = [
            WindowType::Rectangular,
            WindowType::Hann,
            WindowType::Hamming,
            WindowType::Blackman,
            WindowType::BlackmanHarris,
        ];
        for window_type in window_types {
            let mut gate = SpectralGate::new(512, 128, window_type, SAMPLE_RATE);
            let output = process_constant(&mut gate, 0.0, 256);
            assert_eq!(output.len(), 256);
        }
    }

    #[test]
    fn test_spectral_gate_various_fft_sizes() {
        for fft_size in [512, 1024, 2048, 4096] {
            // 75% overlap between consecutive frames.
            let mut gate = hann_gate(fft_size, fft_size / 4);
            let output = process_constant(&mut gate, 0.0, 512);
            assert_eq!(output.len(), 512);
        }
    }

    // Enable flag.

    #[test]
    fn test_spectral_gate_enable_disable() {
        let mut gate = hann_gate(512, 128);
        assert!(gate.is_enabled());
        for enabled in [false, true] {
            gate.set_enabled(enabled);
            assert_eq!(gate.is_enabled(), enabled);
        }
    }
}