use rustfft::{FftPlanner, num_complex::Complex};
/// Default factor by which the estimated noise power is multiplied before it
/// is subtracted from each frame's power (the `oversubtraction` argument of
/// `denoise_with_params`).
pub const DEFAULT_OVERSUBTRACTION: f32 = 1.5;
/// Default lower bound on the per-bin gain (the `spectral_floor` argument of
/// `denoise_with_params`); no bin is attenuated below this factor.
pub const DEFAULT_SPECTRAL_FLOOR: f32 = 0.1;
/// Default length, in milliseconds, of the leading stretch of the signal that
/// is used to estimate the noise profile.
pub const DEFAULT_PROFILE_MS: u32 = 200;
/// Length of one analysis frame in milliseconds; the frame size in samples is
/// derived from this and the sample rate.
const WINDOW_MS: usize = 50;
/// Denoises `samples` in place using the module's default settings.
///
/// This is a convenience wrapper that forwards to [`denoise_with_params`]
/// with [`DEFAULT_PROFILE_MS`], [`DEFAULT_OVERSUBTRACTION`] and
/// [`DEFAULT_SPECTRAL_FLOOR`].
pub fn denoise(samples: &mut [i16], sample_rate: u32) {
    let (profile_ms, oversubtraction, spectral_floor) = (
        DEFAULT_PROFILE_MS,
        DEFAULT_OVERSUBTRACTION,
        DEFAULT_SPECTRAL_FLOOR,
    );
    denoise_with_params(
        samples,
        sample_rate,
        profile_ms,
        oversubtraction,
        spectral_floor,
    );
}
/// Denoises `samples` in place by overlap-add spectral subtraction.
///
/// The signal is cut into Hann-windowed frames of `WINDOW_MS` milliseconds
/// that advance by half a frame. The average power spectrum of the frames
/// lying inside the first `profile_ms` milliseconds is taken as the noise
/// profile. Every frame of the signal is then transformed, each frequency bin
/// is scaled by `max(1 - oversubtraction * noise / frame_power, spectral_floor)`,
/// and the frames are transformed back and overlap-added. The sum is divided
/// by the accumulated window weight, rounded and clamped to the `i16` range.
///
/// # Parameters
/// * `samples` - mono PCM samples, overwritten with the denoised signal.
/// * `sample_rate` - sample rate of `samples` in Hz.
/// * `profile_ms` - length of the leading noise-only stretch in milliseconds.
///   It is clamped to the signal length and is never shorter than one frame.
/// * `oversubtraction` - multiplier applied to the noise power before
///   subtraction; larger values remove more.
/// * `spectral_floor` - smallest gain any bin may receive.
///
/// The function returns without touching `samples` when the slice is empty or
/// when the frame would be shorter than four samples (`sample_rate` < 80 Hz).
pub fn denoise_with_params(
    samples: &mut [i16],
    sample_rate: u32,
    profile_ms: u32,
    oversubtraction: f32,
    spectral_floor: f32,
) {
    let window_size = (sample_rate as usize * WINDOW_MS) / 1000;
    if window_size < 4 || samples.is_empty() {
        return;
    }
    let hop = window_size / 2;
    // A real frame of length N has N/2 + 1 independent bins (DC up to and
    // including Nyquist). Using only N/2 would leave the top bin - and, for
    // odd N, its conjugate partner as well - unprofiled and unattenuated.
    let n_bins = window_size / 2 + 1;
    let n = samples.len();

    let profile_samples = ((sample_rate as usize * profile_ms as usize) / 1000)
        .min(n)
        .max(window_size);

    let hann = hann_periodic(window_size);
    let mut planner = FftPlanner::<f32>::new();
    let fft_forward = planner.plan_fft_forward(window_size);
    let fft_inverse = planner.plan_fft_inverse(window_size);

    // One frame buffer, reused for every transform.
    let mut buf = vec![Complex { re: 0.0f32, im: 0.0 }; window_size];

    // --- Noise profile: mean power per bin over the frames that fit entirely
    // inside the profile region. `profile_samples >= window_size`, so the
    // frame at offset 0 is always included and `profile_count` is at least 1.
    let mut noise_power = vec![0f32; n_bins];
    let mut profile_count = 0usize;
    for start in (0..=profile_samples - window_size).step_by(hop) {
        load_windowed_frame(&mut buf, samples, start, 0, &hann);
        fft_forward.process(&mut buf);
        for (acc, bin) in noise_power.iter_mut().zip(buf.iter()) {
            *acc += bin.norm_sqr();
        }
        profile_count += 1;
    }
    for p in noise_power.iter_mut() {
        *p /= profile_count as f32;
    }

    // --- Subtraction and overlap-add.
    // Frame `pos` covers sample indices `pos - hop .. pos - hop + window_size`,
    // so the first frame starts half a window before the signal (zero padded).
    // Without that lead-in frame the first `hop` samples would be covered only
    // by the rising edge of a single window and would be divided by a weight
    // close to zero, blowing filtered residue up into clipped clicks. The tail
    // is already covered twice because frames keep starting until the end of
    // the signal.
    let scale = 1.0 / window_size as f32;
    let mut output = vec![0f32; n];
    let mut norm = vec![0f32; n];
    for pos in (0..n + hop).step_by(hop) {
        load_windowed_frame(&mut buf, samples, pos, hop, &hann);
        fft_forward.process(&mut buf);

        for k in 0..n_bins {
            let frame_p = buf[k].norm_sqr().max(f32::EPSILON);
            let gain = (1.0 - oversubtraction * noise_power[k] / frame_p).max(spectral_floor);
            buf[k].re *= gain;
            buf[k].im *= gain;
            // Apply the same real gain to the conjugate bin so the inverse
            // transform stays real. DC (k == 0) and, for even frame lengths,
            // Nyquist (mirror == k) have no separate partner.
            let mirror = window_size - k;
            if k != 0 && mirror > k {
                buf[mirror].re *= gain;
                buf[mirror].im *= gain;
            }
        }

        fft_inverse.process(&mut buf);
        for (i, (bin, &w)) in buf.iter().zip(&hann).enumerate() {
            // Positions before the start of the signal belong to the lead-in
            // padding and are discarded.
            let Some(out_idx) = (pos + i).checked_sub(hop) else {
                continue;
            };
            if out_idx >= n {
                break;
            }
            // rustfft does not normalise; `scale` undoes the forward/inverse
            // round trip.
            output[out_idx] += bin.re * scale;
            norm[out_idx] += w;
        }
    }

    for (i, s) in samples.iter_mut().enumerate() {
        // Samples with no usable window weight are left as they were.
        if norm[i] > f32::EPSILON {
            *s = (output[i] / norm[i])
                .round()
                .clamp(i16::MIN as f32, i16::MAX as f32) as i16;
        }
    }
}

/// Fills `buf` with one windowed frame: slot `i` receives the sample at index
/// `pos + i - lead` multiplied by `window[i]`, or zero when that index lies
/// outside `samples`. The imaginary parts are cleared.
fn load_windowed_frame(
    buf: &mut [Complex<f32>],
    samples: &[i16],
    pos: usize,
    lead: usize,
    window: &[f32],
) {
    for (i, (slot, &w)) in buf.iter_mut().zip(window).enumerate() {
        let s = (pos + i)
            .checked_sub(lead)
            .and_then(|idx| samples.get(idx))
            .map_or(0.0, |&v| f32::from(v));
        *slot = Complex { re: s * w, im: 0.0 };
    }
}
/// Builds a periodic Hann window of `n` points.
///
/// Element `i` is `0.5 * (1 - cos(2 * pi * i / n))`; the divisor is `n` rather
/// than `n - 1`, so the window starts at zero but does not return to zero at
/// its last point. An `n` of zero yields an empty vector.
fn hann_periodic(n: usize) -> Vec<f32> {
    let mut window = Vec::with_capacity(n);
    for i in 0..n {
        let phase = 2.0 * std::f32::consts::PI * i as f32 / n as f32;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
#[cfg(test)]
mod tests {
    use rand::RngExt;
    use super::*;
    use crate::analyse::rms_db;

    /// Sample rate, in Hz, used for every synthetic test signal.
    const SR: u32 = 24_000;

    /// Clamps a float to the `i16` range and truncates it to a sample.
    fn clip_to_i16(v: f32) -> i16 {
        v.clamp(i16::MIN as f32, i16::MAX as f32) as i16
    }

    /// Generates `secs` seconds of a sine wave at `freq_hz` with the given
    /// peak `amplitude`.
    fn pure_tone(freq_hz: f32, amplitude: f32, secs: f32) -> Vec<i16> {
        let len = (SR as f32 * secs) as usize;
        let mut tone = Vec::with_capacity(len);
        for i in 0..len {
            let phase = 2.0 * std::f32::consts::PI * freq_hz * i as f32 / SR as f32;
            tone.push(clip_to_i16(amplitude * phase.sin()));
        }
        tone
    }

    /// Generates `n` samples of uniform noise in `[-amplitude, amplitude)`.
    fn white_noise(amplitude: f32, n: usize) -> Vec<i16> {
        let mut rng = rand::rng();
        let mut noise = Vec::with_capacity(n);
        for _ in 0..n {
            let unit = rng.random::<f32>() * 2.0 - 1.0;
            noise.push(clip_to_i16(unit * amplitude));
        }
        noise
    }

    /// An empty slice must be accepted without panicking.
    #[test]
    fn empty_input_is_a_no_op() {
        let mut nothing: Vec<i16> = Vec::new();
        denoise(&mut nothing, SR);
    }

    /// A loud tone that follows a noise-only lead-in must lose less than 6 dB.
    #[test]
    fn speech_tone_preserved_after_denoising() {
        // 200 ms of quiet noise for the profile, then 800 ms of tone.
        let speech_start = (SR as f32 * 0.2) as usize;
        let mut signal = white_noise(200.0, speech_start);
        signal.extend(pure_tone(440.0, 5_000.0, 0.8));

        let before_rms = rms_db(&signal[speech_start..]);
        denoise(&mut signal, SR);
        let after_rms = rms_db(&signal[speech_start..]);

        let attenuation = before_rms - after_rms;
        assert!(
            attenuation < 6.0,
            "Speech attenuated too much: before={:.1} after={:.1}",
            before_rms,
            after_rms
        );
    }

    /// Pure noise, profiled over its whole length, must come out quieter in
    /// the middle of the signal.
    #[test]
    fn noise_is_attenuated_in_steady_state() {
        let total = (SR as usize) * 2;
        let reference = white_noise(500.0, total);
        let mut processed = reference.clone();
        denoise_with_params(
            &mut processed,
            SR,
            2_000,
            DEFAULT_OVERSUBTRACTION,
            DEFAULT_SPECTRAL_FLOOR,
        );

        // Compare the 750 ms .. 1250 ms stretch, away from both edges.
        let middle = (SR as usize * 750) / 1000..(SR as usize * 1250) / 1000;
        let before = rms_db(&reference[middle.clone()]);
        let after = rms_db(&processed[middle]);
        assert!(
            after < before,
            "Noise not attenuated in steady state: before={:.1} after={:.1}",
            before,
            after
        );
    }
}