use crate::analyse::rms_db;
use rustfft::{FftPlanner, num_complex::Complex};
/// Attenuation, in dB, that [`suppress_plosives`] passes to
/// [`suppress_plosives_with_attenuation`].
pub const DEFAULT_ATTENUATION_DB: f32 = 6.0;
/// Frequency in Hz from which the upper (exclusive) FFT bin of the band that
/// is inspected for, and attenuated as, plosive energy is derived.
const PLOSIVE_HI_HZ: f32 = 150.0;
/// Length of one analysis frame in milliseconds.
const WINDOW_MS: usize = 50;
/// Frames whose `rms_db` level falls below this value are never attenuated.
/// NOTE(review): the dB reference level is defined by `crate::analyse::rms_db`
/// and is not visible here — confirm.
const MIN_SPEECH_RMS_DB: f32 = -50.0;
/// Runs plosive suppression over `samples` in place with the default
/// attenuation of [`DEFAULT_ATTENUATION_DB`] dB.
///
/// Convenience wrapper around [`suppress_plosives_with_attenuation`]; see that
/// function for the processing details.
pub fn suppress_plosives(samples: &mut [i16], sample_rate: u32) {
    let attenuation_db = DEFAULT_ATTENUATION_DB;
    suppress_plosives_with_attenuation(samples, sample_rate, attenuation_db)
}
pub fn suppress_plosives_with_attenuation(
samples: &mut [i16],
sample_rate: u32,
attenuation_db: f32,
) {
let window_size = (sample_rate as usize * WINDOW_MS) / 1000;
if window_size < 4 || samples.is_empty() {
return;
}
let hop = window_size / 2;
let half = window_size / 2;
let freq_res = sample_rate as f32 / window_size as f32;
let hi_bin = ((PLOSIVE_HI_HZ / freq_res) as usize).min(half);
let gain = 10f32.powf(-attenuation_db / 20.0);
let plo_threshold = crate::spectral::PLOSIVE_RATIO_THRESHOLD;
let hann = hann_periodic(window_size);
let mut planner = FftPlanner::<f32>::new();
let fft = planner.plan_fft_forward(window_size);
let ifft = planner.plan_fft_inverse(window_size);
let n_frames = samples.len().div_ceil(hop);
let mut output = vec![0f32; samples.len()];
let mut norm = vec![0f32; samples.len()];
let mut frame_buf = vec![0i16; window_size];
for frame_idx in 0..n_frames {
let start = frame_idx * hop;
if start >= samples.len() {
break;
}
for i in 0..window_size {
frame_buf[i] = if start + i < samples.len() {
samples[start + i]
} else {
0
};
}
let mut buffer: Vec<Complex<f32>> = frame_buf
.iter()
.zip(hann.iter())
.map(|(&s, &w)| Complex {
re: s as f32 * w,
im: 0.0,
})
.collect();
fft.process(&mut buffer);
let apply = if hi_bin >= 2 && rms_db(&frame_buf) >= MIN_SPEECH_RMS_DB {
let power: Vec<f32> = buffer[..half].iter().map(|c| c.norm_sqr()).collect();
let total: f32 = power.iter().sum();
if total > f32::EPSILON {
let plo_energy: f32 = power[1..hi_bin].iter().sum();
(plo_energy / total) > plo_threshold
} else {
false
}
} else {
false
};
if apply {
for i in 1..hi_bin {
buffer[i].re *= gain;
buffer[i].im *= gain;
let mirror = window_size - i;
if mirror < window_size && mirror > half {
buffer[mirror].re *= gain;
buffer[mirror].im *= gain;
}
}
}
ifft.process(&mut buffer);
let scale = 1.0 / window_size as f32;
for i in 0..window_size {
let out_idx = start + i;
if out_idx < output.len() {
output[out_idx] += buffer[i].re * scale;
norm[out_idx] += hann[i];
}
}
}
for (i, s) in samples.iter_mut().enumerate() {
let n = norm[i];
if n > f32::EPSILON {
*s = (output[i] / n)
.round()
.clamp(i16::MIN as f32, i16::MAX as f32) as i16;
}
}
}
/// Builds an `n`-point periodic Hann window:
/// `w[k] = 0.5 * (1 - cos(2π·k / n))` for `k` in `0..n`.
///
/// Returns an empty vector when `n` is zero.
fn hann_periodic(n: usize) -> Vec<f32> {
    let two_pi = 2.0 * std::f32::consts::PI;
    let len = n as f32;
    let mut window = Vec::with_capacity(n);
    for k in 0..n {
        let phase = two_pi * k as f32 / len;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Sample rate, in Hz, shared by every synthetic test signal.
    const SR: u32 = 24_000;

    /// Synthesises `secs` seconds of a sine wave at `freq_hz` with the given
    /// peak `amplitude`, clamped into the `i16` range.
    fn pure_tone(freq_hz: f32, amplitude: f32, secs: f32) -> Vec<i16> {
        let sample_count = (SR as f32 * secs) as usize;
        let mut tone = Vec::with_capacity(sample_count);
        for i in 0..sample_count {
            let phase = 2.0 * std::f32::consts::PI * freq_hz * i as f32 / SR as f32;
            let v = amplitude * phase.sin();
            tone.push(v.clamp(i16::MIN as f32, i16::MAX as f32) as i16);
        }
        tone
    }

    /// Root-mean-square level of `samples` on the linear sample scale.
    fn rms(samples: &[i16]) -> f32 {
        let sq: f32 = samples.iter().map(|&s| (s as f32).powi(2)).sum();
        (sq / samples.len() as f32).sqrt()
    }

    /// A 1 kHz tone lies well above the plosive band and must pass through
    /// with less than 0.5 dB of level change.
    #[test]
    fn passthrough_preserves_speech_band() {
        let original = pure_tone(1000.0, 8_000.0, 0.2);
        let mut processed = original.clone();
        suppress_plosives(&mut processed, SR);

        let orig_rms = rms(&original);
        let proc_rms = rms(&processed);
        let diff_db = 20.0 * (proc_rms / orig_rms.max(1.0)).log10();
        assert!(
            diff_db.abs() < 0.5,
            "Speech band altered by {:.2} dB",
            diff_db
        );
    }

    /// An 80 Hz tone sits inside the plosive band and must come out quieter.
    #[test]
    fn plosive_tone_is_attenuated() {
        let mut plosive = pure_tone(80.0, 8_000.0, 0.5);
        let rms_before = rms(&plosive);
        suppress_plosives(&mut plosive, SR);
        let rms_after = rms(&plosive);
        assert!(
            rms_after < rms_before,
            "Plosive tone not attenuated (before={:.0}, after={:.0})",
            rms_before,
            rms_after
        );
    }

    /// Processing an empty buffer must return without panicking.
    #[test]
    fn empty_input_is_a_no_op() {
        let mut samples: Vec<i16> = Vec::new();
        suppress_plosives(&mut samples, SR);
    }
}