use crate::analysis::fft::fft_in_place;
use crate::buffer::AudioBuffer;
/// Samples per analysis frame (power of two, as required by the FFT).
const WINDOW_SIZE: usize = 2048;
/// Frame advance; half the window gives 50% overlap between frames.
const HOP_SIZE: usize = WINDOW_SIZE / 2;
/// Non-redundant frequency bins for a real signal; bins above this
/// index mirror (conjugate) the lower half of the spectrum.
const NUM_BINS: usize = WINDOW_SIZE / 2;
/// Convenience entry point: runs one spectral noise-reduction pass over
/// `buf` using a freshly constructed [`NoiseReducer`].
///
/// `strength` is forwarded unchanged; see [`NoiseReducer::process`] for
/// its meaning and clamping behavior.
pub fn noise_reduce(buf: &mut AudioBuffer, strength: f32) {
    NoiseReducer::new().process(buf, strength);
}
/// Spectral noise gate operating on overlapping windowed FFT frames.
///
/// Scratch buffers live on the struct so repeated `process` calls reuse
/// their allocations instead of re-allocating per call.
#[must_use]
#[derive(Debug, Clone)]
pub struct NoiseReducer {
    /// Precomputed analysis/synthesis window, `WINDOW_SIZE` samples.
    window: Vec<f64>,
    /// Real part of the FFT work buffer.
    real: Vec<f64>,
    /// Imaginary part of the FFT work buffer.
    imag: Vec<f64>,
    /// Per-bin magnitude averaged over all frames; serves as the
    /// noise-floor estimate during gating.
    avg_magnitude: Vec<f64>,
}
impl NoiseReducer {
    /// Creates a reducer with a precomputed Hann window and zeroed
    /// scratch buffers.
    pub fn new() -> Self {
        tracing::debug!("NoiseReducer::new");
        // Hann window: 0.5 * (1 - cos(2*pi*i / (N - 1))).
        let window: Vec<f64> = (0..WINDOW_SIZE)
            .map(|i| {
                0.5 * (1.0
                    - (2.0 * std::f64::consts::PI * i as f64 / (WINDOW_SIZE - 1) as f64).cos())
            })
            .collect();
        Self {
            window,
            real: vec![0.0f64; WINDOW_SIZE],
            imag: vec![0.0f64; WINDOW_SIZE],
            avg_magnitude: vec![0.0f64; NUM_BINS],
        }
    }

    /// Applies spectral noise reduction to every channel of `buf` in place.
    ///
    /// `strength` is clamped to `[0.0, 1.0]`. Buffers shorter than one FFT
    /// window fall back to a simple time-domain gate that zeroes samples
    /// whose magnitude is below `strength * 0.05`.
    pub fn process(&mut self, buf: &mut AudioBuffer, strength: f32) {
        tracing::debug!(
            frames = buf.frames,
            channels = buf.channels,
            strength,
            "NoiseReducer::process"
        );
        let strength = strength.clamp(0.0, 1.0);
        if buf.frames < WINDOW_SIZE {
            // Too short for spectral processing: cheap per-sample gate instead.
            let threshold = strength * 0.05;
            for s in &mut buf.samples {
                if s.abs() < threshold {
                    *s = 0.0;
                }
            }
            return;
        }
        let ch = buf.channels as usize;
        // Each channel is deinterleaved into a mono scratch vector,
        // processed independently, then written back in place.
        for c in 0..ch {
            let mut mono: Vec<f32> = (0..buf.frames).map(|f| buf.samples[f * ch + c]).collect();
            self.process_channel(&mut mono, strength);
            for (f, &sample) in mono.iter().enumerate() {
                buf.samples[f * ch + c] = sample;
            }
        }
    }

    /// Two-pass spectral gate over one mono channel.
    ///
    /// Pass 1 estimates the noise floor as the average magnitude per bin
    /// across all frames. Pass 2 attenuates bins whose magnitude falls
    /// below `avg_magnitude[k] * strength * 1.5` and resynthesizes the
    /// signal via overlap-add with window-sum normalization.
    fn process_channel(&mut self, samples: &mut [f32], strength: f32) {
        let n = samples.len();
        if n < WINDOW_SIZE {
            return;
        }
        // --- Pass 1: accumulate per-bin magnitudes for the noise floor. ---
        self.avg_magnitude.fill(0.0);
        let mut frame_count = 0usize;
        let mut pos = 0;
        while pos + WINDOW_SIZE <= n {
            self.real.fill(0.0);
            self.imag.fill(0.0);
            for i in 0..WINDOW_SIZE {
                self.real[i] = samples[pos + i] as f64 * self.window[i];
            }
            // fft_in_place returning false signals a failed transform;
            // the frame is simply skipped in that case.
            if !fft_in_place(&mut self.real, &mut self.imag) {
                pos += HOP_SIZE;
                continue;
            }
            for k in 0..NUM_BINS {
                self.avg_magnitude[k] +=
                    (self.real[k] * self.real[k] + self.imag[k] * self.imag[k]).sqrt();
            }
            frame_count += 1;
            pos += HOP_SIZE;
        }
        if frame_count == 0 {
            // Every frame failed to transform; leave the input untouched.
            return;
        }
        for m in &mut self.avg_magnitude {
            *m /= frame_count as f64;
        }
        // --- Pass 2: gate each frame and overlap-add the result. ---
        let mut output = vec![0.0f64; n];
        let mut window_sum = vec![0.0f64; n];
        // Hoisted: the gate scale is loop-invariant (was recomputed per frame).
        let gate_factor = strength as f64 * 1.5;
        pos = 0;
        while pos + WINDOW_SIZE <= n {
            self.real.fill(0.0);
            self.imag.fill(0.0);
            for i in 0..WINDOW_SIZE {
                self.real[i] = samples[pos + i] as f64 * self.window[i];
            }
            if !fft_in_place(&mut self.real, &mut self.imag) {
                pos += HOP_SIZE;
                continue;
            }
            for k in 0..NUM_BINS {
                let mag = (self.real[k] * self.real[k] + self.imag[k] * self.imag[k]).sqrt();
                let threshold = self.avg_magnitude[k] * gate_factor;
                if mag < threshold && threshold > 0.0 {
                    // Soft gate: scale the bin proportionally to how far it
                    // sits below the threshold, rather than hard-zeroing.
                    let attenuation = mag / threshold;
                    self.real[k] *= attenuation;
                    self.imag[k] *= attenuation;
                    // Mirror the attenuation onto the conjugate bin so the
                    // spectrum stays symmetric and the inverse transform
                    // remains real. DC (k == 0) has no mirror. (The original
                    // extra check `k < NUM_BINS` was always true in this
                    // loop and has been dropped.)
                    if k > 0 {
                        let mirror = WINDOW_SIZE - k;
                        self.real[mirror] *= attenuation;
                        self.imag[mirror] *= attenuation;
                    }
                }
            }
            // Inverse FFT via the conjugate trick: IFFT(x) = conj(FFT(conj(x))) / N.
            // The frame is real-valued, so only the scaled real part is used.
            for v in &mut self.imag {
                *v = -*v;
            }
            if !fft_in_place(&mut self.real, &mut self.imag) {
                pos += HOP_SIZE;
                continue;
            }
            let scale = 1.0 / WINDOW_SIZE as f64;
            for r in self.real.iter_mut() {
                *r *= scale;
            }
            // Overlap-add; window energy is accumulated for normalization.
            for i in 0..WINDOW_SIZE {
                output[pos + i] += self.real[i] * self.window[i];
                window_sum[pos + i] += self.window[i] * self.window[i];
            }
            pos += HOP_SIZE;
        }
        // Normalize by accumulated window energy. Samples never covered by a
        // complete window (the tail shorter than WINDOW_SIZE) keep their
        // original values; non-finite results are clamped to silence.
        for i in 0..n {
            if window_sum[i] > 1e-10 {
                samples[i] = (output[i] / window_sum[i]) as f32;
            }
            if !samples[i].is_finite() {
                samples[i] = 0.0;
            }
        }
    }
}
impl Default for NoiseReducer {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Pure silence has nothing to gate; the output must remain silent.
    #[test]
    fn silence_unchanged() {
        let mut buf = AudioBuffer::silence(1, 4096, 44100);
        noise_reduce(&mut buf, 0.5);
        assert!(buf.peak() < f32::EPSILON);
    }

    // A loud 440 Hz tone sits well above the noise floor, so its energy
    // should survive processing largely intact.
    #[test]
    fn loud_signal_preserved() {
        let sr = 44100u32;
        let samples: Vec<f32> = (0..sr as usize)
            .map(|i| 0.8 * (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin())
            .collect();
        let mut buf = AudioBuffer::from_interleaved(samples, 1, sr).unwrap();
        let original_rms = buf.rms();
        noise_reduce(&mut buf, 0.3);
        assert!(
            buf.rms() > original_rms * 0.7,
            "Loud signal should survive: rms={} vs original={}",
            buf.rms(),
            original_rms
        );
    }

    // Tone plus low-level deterministic pseudo-noise: processing must keep
    // the output finite and retain the dominant signal.
    #[test]
    fn noise_reduced() {
        let sr = 44100u32;
        let samples: Vec<f32> = (0..sr as usize)
            .map(|i| {
                let signal =
                    0.5 * (2.0 * std::f32::consts::PI * 440.0 * i as f32 / sr as f32).sin();
                // Cheap deterministic "noise" via a high-frequency sine.
                let noise = 0.02 * ((i as f32 * 12_345.679).sin());
                signal + noise
            })
            .collect();
        let mut buf = AudioBuffer::from_interleaved(samples, 1, sr).unwrap();
        noise_reduce(&mut buf, 0.5);
        assert!(buf.samples.iter().all(|s| s.is_finite()));
        assert!(buf.rms() > 0.0, "Signal should survive");
    }

    // 100 frames < WINDOW_SIZE, so the simple time-domain fallback path runs.
    #[test]
    fn short_buffer_fallback() {
        let mut buf = AudioBuffer::from_interleaved(vec![0.01; 100], 1, 44100).unwrap();
        noise_reduce(&mut buf, 0.5);
        assert!(buf.samples.iter().all(|s| s.is_finite()));
    }

    // Interleaved stereo input: both channels are processed independently.
    #[test]
    fn stereo_processing() {
        let samples: Vec<f32> = (0..88200)
            .map(|i| 0.5 * (2.0 * std::f32::consts::PI * 440.0 * (i / 2) as f32 / 44100.0).sin())
            .collect();
        let mut buf = AudioBuffer::from_interleaved(samples, 2, 44100).unwrap();
        noise_reduce(&mut buf, 0.3);
        assert!(buf.samples.iter().all(|s| s.is_finite()));
        assert!(buf.rms() > 0.0);
    }

    // Full-strength processing on a ramp must never produce NaN/Inf samples.
    #[test]
    fn output_finite() {
        let samples: Vec<f32> = (0..44100)
            .map(|i| (i as f32 / 44100.0) * 2.0 - 1.0)
            .collect();
        let mut buf = AudioBuffer::from_interleaved(samples, 1, 44100).unwrap();
        noise_reduce(&mut buf, 1.0);
        assert!(buf.samples.iter().all(|s| s.is_finite()));
    }
}