use super::basis::{Complex64, DFT_SIZE};
// ---- Suppression gain (see `apply_subtraction`) ----

/// Scale applied to the noise PSD before it is subtracted from the signal PSD
/// (Boll mode).
const SUBTRACTION_BETA: f64 = 0.1;
/// Lower bound on the post-subtraction PSD, expressed as a fraction of the
/// input PSD (Boll mode). Its square root is the minimum amplitude gain in
/// Wiener mode.
const SPECTRAL_FLOOR_ALPHA: f64 = 0.02;

// ---- Noise estimate (see `NoiseSpectrum::update`) ----

/// Weight kept on the previous noise estimate when an eligible frame is
/// blended in; the new frame contributes `1 - NOISE_EMA_LAMBDA`.
const NOISE_EMA_LAMBDA: f64 = 0.95;
/// Value every noise-PSD bin holds after `new`/`reset`, until the first
/// eligible frame seeds the estimate.
const NOISE_INIT: f64 = 0.0;
/// A frame that is not flagged silent passes the energy gate only if its
/// energy is below this multiple of the caller-supplied noise floor.
const STRICT_ENERGY_RATIO: f64 = 4.0;
/// Length the run of consecutive gate-passing frames must reach before a frame
/// is allowed to modify the noise estimate.
const STRICT_CONSECUTIVE_FRAMES: u32 = 2;
/// Minimum cosine similarity between a frame's PSD and the running reference
/// PSD for the frame to count as stationary.
const STATIONARITY_COS_THRESHOLD: f64 = 0.999;
/// Weight kept on the previous reference PSD when the stationarity reference
/// is updated; the new frame contributes `1 - STATIONARITY_REF_LAMBDA`.
const STATIONARITY_REF_LAMBDA: f64 = 0.7;

// ---- Wiener mode only ----

/// Weight on the previous call's enhanced-PSD term in the recursive a-priori
/// SNR estimate; the instantaneous SNR term gets `1 - DD_SNR_SMOOTHING`.
const DD_SNR_SMOOTHING: f64 = 0.98;
/// Lower bound on the a-priori SNR estimate. Numerically 10^-2.5, i.e. about
/// -25 dB when read as a power ratio.
const A_PRIORI_SNR_FLOOR: f64 = 0.003_162_277_660;
/// Gain rule `apply_subtraction` dispatches on; fixed at compile time.
const SUPPRESSION_MODE: SuppressionMode = SuppressionMode::Boll;
/// Gain rule selected by `SUPPRESSION_MODE` and applied per bin in
/// `apply_subtraction`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum SuppressionMode {
/// Subtract a scaled noise PSD from the signal PSD, clamp the result to a
/// fraction of the signal PSD, and scale the bin by the square root of the
/// resulting power ratio.
Boll,
// Never constructed while `SUPPRESSION_MODE` is `Boll`; the allow keeps the
// alternative rule compiled without a dead-code warning.
#[allow(dead_code)]
/// Gain `xi / (1 + xi)`, where `xi` is an a-priori SNR estimate smoothed
/// across calls and bounded below, with a minimum gain applied afterwards.
Wiener,
}
/// Running per-bin noise power estimate together with the gating state that
/// decides which frames may update it.
#[derive(Clone, Debug)]
pub struct NoiseSpectrum {
/// Per-bin noise power estimate (same scale as `Complex64::norm_sqr` of a
/// spectrum bin).
pub n_psd: [f64; DFT_SIZE],
/// `true` once `n_psd` has been seeded from a frame. While `false`,
/// `apply_subtraction` returns its input unchanged.
pub primed: bool,
/// Number of consecutive frames that have passed both the stationarity and
/// the energy gate in `update`; reset to 0 by any frame that fails either.
eligible_run: u32,
/// Exponentially smoothed PSD of the frames seen by `update`, used as the
/// reference for the stationarity (cosine-similarity) check.
spectrum_ref: [f64; DFT_SIZE],
/// `true` once `spectrum_ref` holds at least one frame.
spectrum_ref_primed: bool,
/// Per-bin enhanced PSD written by the Wiener branch of `apply_subtraction`
/// and read back on its next call.
prev_enhanced_psd: [f64; DFT_SIZE],
}
impl NoiseSpectrum {
    /// Creates an unprimed estimator: every noise bin holds `NOISE_INIT`, no
    /// stationarity reference exists yet, and the eligible-frame run is empty.
    pub const fn new() -> Self {
        Self {
            n_psd: [NOISE_INIT; DFT_SIZE],
            primed: false,
            eligible_run: 0,
            spectrum_ref: [0.0; DFT_SIZE],
            spectrum_ref_primed: false,
            prev_enhanced_psd: [0.0; DFT_SIZE],
        }
    }

    /// Returns the estimator to the exact state produced by [`NoiseSpectrum::new`].
    pub fn reset(&mut self) {
        // Delegating keeps `reset` and `new` from drifting apart when a field
        // is added.
        *self = Self::new();
    }

    /// Per-bin power (`norm_sqr`) of one frame.
    fn power_spectrum(sw: &[Complex64; DFT_SIZE]) -> [f64; DFT_SIZE] {
        let mut psd = [0.0f64; DFT_SIZE];
        for (slot, bin) in psd.iter_mut().zip(sw.iter()) {
            *slot = bin.norm_sqr();
        }
        psd
    }

    /// Cosine similarity between `psd` and the running reference PSD.
    ///
    /// Returns `0.0` when no reference exists yet or when either vector has
    /// (numerically) zero norm.
    fn cos_sim_with_reference(&self, psd: &[f64; DFT_SIZE]) -> f64 {
        if !self.spectrum_ref_primed {
            return 0.0;
        }
        let mut dot = 0.0f64;
        let mut frame_norm = 0.0f64;
        let mut ref_norm = 0.0f64;
        for (&a, &b) in psd.iter().zip(self.spectrum_ref.iter()) {
            dot += a * b;
            frame_norm += a * a;
            ref_norm += b * b;
        }
        let denom = (frame_norm * ref_norm).sqrt();
        if denom <= 1e-30 {
            0.0
        } else {
            dot / denom
        }
    }

    /// Cosine similarity between the PSD of `sw` and the running reference PSD.
    fn stationarity_cos_sim(&self, sw: &[Complex64; DFT_SIZE]) -> f64 {
        self.cos_sim_with_reference(&Self::power_spectrum(sw))
    }

    /// Offers one frame to the noise estimator.
    ///
    /// The frame only modifies `n_psd` when all of the following hold:
    /// * its PSD is similar enough (cosine similarity of at least
    ///   `STATIONARITY_COS_THRESHOLD`) to the running reference PSD,
    /// * it is flagged `silent`, or `frame_energy` is below
    ///   `STRICT_ENERGY_RATIO * noise_floor_eta`,
    /// * it extends the run of such frames to at least
    ///   `STRICT_CONSECUTIVE_FRAMES`.
    ///
    /// The first qualifying frame seeds `n_psd` directly and sets `primed`;
    /// later ones are blended in with weight `1 - NOISE_EMA_LAMBDA`.
    ///
    /// The stationarity reference is updated by every finite frame, whether or
    /// not that frame passes the gates. The very first frame can never pass,
    /// because there is no reference to compare it against.
    ///
    /// A frame whose PSD contains NaN or infinity is ignored entirely, apart
    /// from breaking the eligible run.
    pub fn update(
        &mut self,
        sw: &[Complex64; DFT_SIZE],
        silent: bool,
        frame_energy: f64,
        noise_floor_eta: f64,
    ) {
        let psd = Self::power_spectrum(sw);

        // Blending a non-finite value into the reference would leave it
        // non-finite forever: every later similarity would be NaN, no frame
        // would ever count as stationary, and the noise estimate would stay
        // frozen until `reset`. Drop such frames before they touch any state.
        if !psd.iter().all(|p| p.is_finite()) {
            self.eligible_run = 0;
            return;
        }

        // Compare against the reference *before* folding this frame into it.
        let stationary = self.cos_sim_with_reference(&psd) >= STATIONARITY_COS_THRESHOLD;

        if self.spectrum_ref_primed {
            for (slot, &cur) in self.spectrum_ref.iter_mut().zip(psd.iter()) {
                *slot = STATIONARITY_REF_LAMBDA * *slot + (1.0 - STATIONARITY_REF_LAMBDA) * cur;
            }
        } else {
            self.spectrum_ref = psd;
            self.spectrum_ref_primed = true;
        }

        let energy_ok = silent || frame_energy < STRICT_ENERGY_RATIO * noise_floor_eta;
        if !(stationary && energy_ok) {
            self.eligible_run = 0;
            return;
        }

        self.eligible_run = self.eligible_run.saturating_add(1);
        if self.eligible_run < STRICT_CONSECUTIVE_FRAMES {
            return;
        }

        if self.primed {
            for (slot, &cur) in self.n_psd.iter_mut().zip(psd.iter()) {
                *slot = NOISE_EMA_LAMBDA * *slot + (1.0 - NOISE_EMA_LAMBDA) * cur;
            }
        } else {
            // First eligible frame: adopt it as the estimate instead of
            // blending it with the `NOISE_INIT` placeholder.
            self.n_psd = psd;
            self.primed = true;
        }
    }
}
impl Default for NoiseSpectrum {
#[inline]
fn default() -> Self {
Self::new()
}
}
/// Attenuates each bin of `sw` according to the current noise estimate.
///
/// Returns a new spectrum; `sw` itself is not modified. While `state` is not
/// primed the input is returned unchanged. Bins whose power is at most
/// `1e-30` are always passed through as they are.
///
/// Only the gain of a bin is changed: real and imaginary parts are multiplied
/// by the same non-negative factor.
///
/// `state` is mutable because the Wiener rule stores each bin's enhanced
/// power for use on the next call; the Boll rule only reads `state`.
pub fn apply_subtraction(
    sw: &[Complex64; DFT_SIZE],
    state: &mut NoiseSpectrum,
) -> [Complex64; DFT_SIZE] {
    // Start from a copy so that every bin skipped below is already correct.
    let mut out = *sw;
    if !state.primed {
        return out;
    }

    match SUPPRESSION_MODE {
        SuppressionMode::Boll => {
            let bins = out.iter_mut().zip(sw.iter()).zip(state.n_psd.iter());
            for ((dst, src), &noise_psd) in bins {
                let signal_psd = src.norm_sqr();
                if signal_psd <= 1e-30 {
                    continue;
                }
                // Subtract the scaled noise power, but never drop below a
                // fixed fraction of the input power.
                let residual = signal_psd - SUBTRACTION_BETA * noise_psd;
                let floor = SPECTRAL_FLOOR_ALPHA * signal_psd;
                let gain = (residual.max(floor) / signal_psd).sqrt();
                *dst = Complex64::new(src.re * gain, src.im * gain);
            }
        }
        SuppressionMode::Wiener => {
            let min_gain = SPECTRAL_FLOOR_ALPHA.sqrt();
            let bins = out
                .iter_mut()
                .zip(sw.iter())
                .zip(state.n_psd.iter())
                .zip(state.prev_enhanced_psd.iter_mut());
            for (((dst, src), &noise_est), prev_enhanced) in bins {
                let signal_psd = src.norm_sqr();
                if signal_psd <= 1e-30 {
                    *prev_enhanced = 0.0;
                    continue;
                }
                // Guard the divisions below against a zero noise estimate.
                let noise_psd = noise_est.max(1e-30);
                let instantaneous = (signal_psd / noise_psd - 1.0).max(0.0);
                let smoothed = DD_SNR_SMOOTHING * *prev_enhanced / noise_psd
                    + (1.0 - DD_SNR_SMOOTHING) * instantaneous;
                let xi = smoothed.max(A_PRIORI_SNR_FLOOR);
                let gain = (xi / (1.0 + xi)).max(min_gain);
                *dst = Complex64::new(src.re * gain, src.im * gain);
                *prev_enhanced = gain * gain * signal_psd;
            }
        }
    }
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Spectrum in which every bin is the purely real value `magnitude`.
    fn flat_spectrum(magnitude: f64) -> [Complex64; DFT_SIZE] {
        [Complex64::new(magnitude, 0.0); DFT_SIZE]
    }

    /// Feeds `sw` as a silent frame `STRICT_CONSECUTIVE_FRAMES + 1` times: one
    /// frame to establish the stationarity reference, then enough to satisfy
    /// the consecutive-frame gate.
    fn prime_with(state: &mut NoiseSpectrum, sw: &[Complex64; DFT_SIZE]) {
        (0..=STRICT_CONSECUTIVE_FRAMES).for_each(|_| state.update(sw, true, 0.0, 1.0));
    }

    /// Without a noise estimate the input must come back unchanged.
    #[test]
    fn unprimed_noise_state_is_passthrough() {
        let input = flat_spectrum(2.0);
        let mut state = NoiseSpectrum::new();
        let output = apply_subtraction(&input, &mut state);
        for (got, want) in output.iter().zip(input.iter()) {
            assert!((got.re - want.re).abs() < 1e-12);
        }
    }

    /// The estimate is seeded only once the run of stationary frames is long
    /// enough, and it is seeded with the frame's own power.
    #[test]
    fn strict_gate_requires_consecutive_stationary_frames() {
        let mut state = NoiseSpectrum::new();
        let noise = flat_spectrum(0.5);

        (0..STRICT_CONSECUTIVE_FRAMES).for_each(|_| state.update(&noise, true, 0.0, 1.0));
        assert!(!state.primed);

        state.update(&noise, true, 0.0, 1.0);
        assert!(state.primed);
        for psd in state.n_psd.iter() {
            assert!((psd - 0.25).abs() < 1e-9);
        }
    }

    /// A frame with a very different spectral shape must not prime the
    /// estimator, even when it arrives right where priming would occur.
    #[test]
    fn non_stationary_frame_breaks_eligible_run() {
        let mut state = NoiseSpectrum::new();
        let noise = flat_spectrum(0.5);
        (0..STRICT_CONSECUTIVE_FRAMES).for_each(|_| state.update(&noise, true, 0.0, 1.0));

        let mut shifted = flat_spectrum(0.0);
        shifted[10] = Complex64::new(100.0, 0.0);
        state.update(&shifted, true, 0.0, 1.0);

        assert!(!state.primed);
    }

    /// A loud, non-silent frame must leave an existing estimate untouched.
    #[test]
    fn voiced_frames_hold_noise_estimate() {
        let mut state = NoiseSpectrum::new();
        let noise = flat_spectrum(0.5);
        prime_with(&mut state, &noise);
        let baseline = state.n_psd[0];

        let mut voice = flat_spectrum(0.0);
        voice[10] = Complex64::new(50.0, 0.0);
        state.update(&voice, false, 100.0, 1.0);

        for (m, psd) in state.n_psd.iter().enumerate() {
            assert!((psd - baseline).abs() < 1e-9,
                "voiced update leaked: bin {m} = {}", psd);
        }
    }

    /// A signal at the noise level is attenuated, by no less than the
    /// configured subtraction allows.
    #[test]
    fn subtraction_attenuates_signal_at_or_below_noise_psd() {
        let mut state = NoiseSpectrum::new();
        let train = flat_spectrum(0.5);
        prime_with(&mut state, &train);

        let probe = flat_spectrum(0.5);
        let out = apply_subtraction(&probe, &mut state);

        let expected_gain = ((1.0 - SUBTRACTION_BETA).max(SPECTRAL_FLOOR_ALPHA)).sqrt();
        let expected_max = expected_gain * 0.5 + 1e-9;
        for (m, bin) in out.iter().enumerate() {
            let magnitude = bin.re.abs();
            assert!(magnitude <= expected_max,
                "bin {m} expected ≤ {expected_max:.4}, got {}", bin.re);
            assert!(magnitude < 0.5,
                "bin {m} not attenuated at all: {}", bin.re);
        }
    }

    /// A signal far above the noise level passes almost unchanged.
    #[test]
    fn subtraction_passes_strong_signal_through() {
        let mut state = NoiseSpectrum::new();
        let noise = flat_spectrum(0.1);
        prime_with(&mut state, &noise);

        let signal = flat_spectrum(10.0);
        let out = apply_subtraction(&signal, &mut state);

        for (m, bin) in out.iter().enumerate() {
            assert!((bin.re - 10.0).abs() < 0.5,
                "bin {m} over-attenuated: {}", bin.re);
            assert!(bin.re.abs() > 5.0,
                "bin {m} excessively attenuated: {}", bin.re);
        }
    }
}