Skip to main content

any_tts/audio/
denoise.rs

1use std::cmp::Ordering;
2use std::f32::consts::PI;
3
4use rustfft::num_complex::Complex32;
5use rustfft::{Fft, FftPlanner};
6
7use super::AudioSamples;
8
/// Tuning knobs for the speech-focused denoiser.
///
/// Processing runs in three stages: a speech-band (high-pass + low-pass)
/// filter, a short-time spectral suppressor, and an adaptive residual-noise
/// gate. The goal is to attenuate steady background noise and background
/// music; this is not a full source-separation algorithm.
///
/// Values are clamped into safe ranges internally before use.
#[derive(Debug, Clone, Copy)]
pub struct DenoiseOptions {
    /// FFT frame size, in samples, used by the spectral gate (raised to at least 128).
    pub frame_size: usize,
    /// Distance between consecutive frames, in samples (kept within `1..=frame_size`).
    pub hop_size: usize,
    /// Cutoff of the speech-band high-pass filter, in Hz.
    pub speech_low_hz: f32,
    /// Cutoff of the speech-band low-pass filter, in Hz.
    pub speech_high_hz: f32,
    /// Fraction of the quietest frames averaged to estimate the noise profile.
    pub noise_estimation_percentile: f32,
    /// Over-subtraction factor applied to the estimated noise spectrum.
    pub noise_reduction: f32,
    /// Minimum spectral gain kept in every bin to limit musical-noise artifacts.
    pub residual_floor: f32,
    /// Blend between the denoised signal (1.0) and the speech-band filtered signal (0.0).
    pub wet_mix: f32,
}

impl Default for DenoiseOptions {
    /// Speech-oriented defaults: 1024-sample frames with 75% overlap, a
    /// 110 Hz–5.8 kHz speech band, 20% of frames used for the noise profile,
    /// 1.35× over-subtraction, an 8% residual floor and a 90% wet mix.
    fn default() -> Self {
        const FRAME: usize = 1024;
        DenoiseOptions {
            frame_size: FRAME,
            hop_size: FRAME / 4,
            speech_low_hz: 110.0,
            speech_high_hz: 5_800.0,
            noise_estimation_percentile: 0.2,
            noise_reduction: 1.35,
            residual_floor: 0.08,
            wet_mix: 0.9,
        }
    }
}
49
50pub(super) fn denoise_audio_samples(audio: &AudioSamples, options: DenoiseOptions) -> AudioSamples {
51    if audio.is_empty() || audio.sample_rate == 0 {
52        return audio.clone();
53    }
54
55    let config = SanitizedDenoiseOptions::new(audio.sample_rate, options);
56    let filtered = apply_speech_bandpass(&audio.samples, &config);
57    let window = hann_window(config.frame_size);
58    let mut planner = FftPlanner::<f32>::new();
59    let fft = planner.plan_fft_forward(config.frame_size);
60    let ifft = planner.plan_fft_inverse(config.frame_size);
61    let noise_estimate = estimate_noise_profile(&filtered, &config, &window, fft.as_ref());
62    let cleaned = render_denoised_samples(
63        &filtered,
64        &config,
65        &window,
66        fft.as_ref(),
67        ifft.as_ref(),
68        &noise_estimate,
69    );
70
71    AudioSamples::new(cleaned, audio.sample_rate)
72}
73
74fn estimate_noise_profile(
75    samples: &[f32],
76    config: &SanitizedDenoiseOptions,
77    window: &[f32],
78    fft: &dyn Fft<f32>,
79) -> NoiseEstimate {
80    let selected_offsets = select_quiet_frame_offsets(samples, config);
81    let mut profile = vec![0.0; config.frame_size];
82    let mut buffer = vec![Complex32::default(); config.frame_size];
83    let mut quiet_rms_sum = 0.0;
84
85    for &start in &selected_offsets {
86        quiet_rms_sum += frame_rms(samples, start, config.frame_size);
87        load_windowed_frame(samples, start, window, &mut buffer);
88        fft.process(&mut buffer);
89        for (value, spectrum) in profile.iter_mut().zip(&buffer) {
90            *value += spectrum.norm();
91        }
92    }
93
94    let frame_count = selected_offsets.len().max(1) as f32;
95    normalize_and_smooth(&mut profile, frame_count);
96    NoiseEstimate {
97        spectrum: profile,
98        quiet_rms: quiet_rms_sum / frame_count,
99    }
100}
101
102fn select_quiet_frame_offsets(samples: &[f32], config: &SanitizedDenoiseOptions) -> Vec<usize> {
103    let mut ranked: Vec<(usize, f32)> =
104        frame_offsets(samples.len(), config.frame_size, config.hop_size)
105            .into_iter()
106            .map(|start| (start, frame_rms(samples, start, config.frame_size)))
107            .collect();
108    ranked.sort_by(|left, right| left.1.partial_cmp(&right.1).unwrap_or(Ordering::Equal));
109
110    ranked
111        .into_iter()
112        .take(quiet_frame_count(samples, config))
113        .map(|(start, _)| start)
114        .collect()
115}
116
117fn quiet_frame_count(samples: &[f32], config: &SanitizedDenoiseOptions) -> usize {
118    let total_frames = frame_offsets(samples.len(), config.frame_size, config.hop_size)
119        .len()
120        .max(1);
121    ((total_frames as f32 * config.noise_estimation_percentile).ceil() as usize)
122        .clamp(1, total_frames)
123}
124
125fn render_denoised_samples(
126    samples: &[f32],
127    config: &SanitizedDenoiseOptions,
128    window: &[f32],
129    fft: &dyn Fft<f32>,
130    ifft: &dyn Fft<f32>,
131    noise_estimate: &NoiseEstimate,
132) -> Vec<f32> {
133    let offsets = frame_offsets(samples.len(), config.frame_size, config.hop_size);
134    let mut overlap_add = vec![0.0; samples.len() + config.frame_size];
135    let mut normalization = vec![0.0; samples.len() + config.frame_size];
136    let mut buffer = vec![Complex32::default(); config.frame_size];
137    let mut mask = vec![0.0; config.frame_size];
138    let mut adaptive_noise = noise_estimate.spectrum.clone();
139    let mut previous_mask = vec![1.0; config.frame_size];
140
141    for start in offsets {
142        load_windowed_frame(samples, start, window, &mut buffer);
143        fft.process(&mut buffer);
144        update_adaptive_noise_profile(
145            &buffer,
146            &noise_estimate.spectrum,
147            &mut adaptive_noise,
148            frame_rms(samples, start, config.frame_size),
149            noise_estimate.quiet_rms,
150            config,
151        );
152        build_spectral_mask(&buffer, &adaptive_noise, &previous_mask, config, &mut mask);
153        previous_mask.clone_from_slice(&mask);
154        apply_mask(&mut buffer, &mask);
155        ifft.process(&mut buffer);
156        overlap_add_frame(&buffer, start, window, &mut overlap_add, &mut normalization);
157    }
158
159    finalize_samples(
160        samples,
161        &overlap_add,
162        &normalization,
163        config,
164        noise_estimate.quiet_rms,
165    )
166}
167
168fn build_spectral_mask(
169    spectrum: &[Complex32],
170    noise_profile: &[f32],
171    previous_mask: &[f32],
172    config: &SanitizedDenoiseOptions,
173    mask: &mut [f32],
174) {
175    const EPSILON: f32 = 1e-6;
176
177    for index in 0..mask.len() {
178        let magnitude = spectrum[index].norm();
179        let noise = noise_profile[index].max(EPSILON);
180        let power = magnitude * magnitude;
181        let noise_power = noise * noise;
182        let wiener =
183            (1.0 - config.noise_reduction * noise_power / (power + EPSILON)).clamp(0.0, 1.0);
184        let snr_db = 10.0 * ((power + EPSILON) / (noise_power + EPSILON)).log10();
185        let soft_gate = ((snr_db + 3.0) / 12.0).clamp(0.0, 1.0);
186        let harmonic_ratio = magnitude / noise;
187        let harmonic_guard = ((harmonic_ratio - 0.75) / 2.25).clamp(0.0, 1.0);
188        let gain = wiener * (0.2 + 0.8 * soft_gate) * (0.25 + 0.75 * harmonic_guard);
189        let temporal = if gain > previous_mask[index] {
190            previous_mask[index] * 0.25 + gain * 0.75
191        } else {
192            previous_mask[index] * 0.4 + gain * 0.6
193        };
194        mask[index] = config.residual_floor + (1.0 - config.residual_floor) * temporal;
195    }
196    smooth_in_place(mask);
197    smooth_in_place(mask);
198}
199
200fn update_adaptive_noise_profile(
201    spectrum: &[Complex32],
202    base_noise: &[f32],
203    adaptive_noise: &mut [f32],
204    frame_rms: f32,
205    quiet_rms: f32,
206    config: &SanitizedDenoiseOptions,
207) {
208    let quiet_threshold = quiet_rms.max(1e-5) * (1.4 + config.noise_estimation_percentile);
209    let update_rate = if frame_rms <= quiet_threshold {
210        0.18
211    } else {
212        0.035
213    };
214    let growth_limit = if frame_rms <= quiet_threshold {
215        2.0
216    } else {
217        1.15
218    };
219
220    for index in 0..adaptive_noise.len() {
221        let capped_magnitude = spectrum[index]
222            .norm()
223            .min(base_noise[index] * growth_limit + 1e-5);
224        adaptive_noise[index] =
225            adaptive_noise[index] * (1.0 - update_rate) + capped_magnitude * update_rate;
226        adaptive_noise[index] = adaptive_noise[index].max(base_noise[index] * 0.5);
227    }
228
229    smooth_in_place(adaptive_noise);
230}
231
232fn apply_mask(spectrum: &mut [Complex32], mask: &[f32]) {
233    for (bin, value) in spectrum.iter_mut().zip(mask) {
234        *bin *= *value;
235    }
236}
237
238fn overlap_add_frame(
239    frame: &[Complex32],
240    start: usize,
241    window: &[f32],
242    overlap_add: &mut [f32],
243    normalization: &mut [f32],
244) {
245    let frame_size = window.len() as f32;
246    for index in 0..window.len() {
247        let sample = frame[index].re / frame_size;
248        let windowed = sample * window[index];
249        overlap_add[start + index] += windowed;
250        normalization[start + index] += window[index] * window[index];
251    }
252}
253
254fn finalize_samples(
255    filtered: &[f32],
256    overlap_add: &[f32],
257    normalization: &[f32],
258    config: &SanitizedDenoiseOptions,
259    quiet_rms: f32,
260) -> Vec<f32> {
261    let blended = (0..filtered.len())
262        .map(|index| {
263            let restored = if normalization[index] > 1e-6 {
264                overlap_add[index] / normalization[index]
265            } else {
266                0.0
267            };
268            (restored * config.wet_mix + filtered[index] * (1.0 - config.wet_mix)).clamp(-1.0, 1.0)
269        })
270        .collect::<Vec<_>>();
271
272    apply_adaptive_noise_gate(&blended, filtered, config, quiet_rms)
273}
274
275fn apply_adaptive_noise_gate(
276    denoised: &[f32],
277    sidechain: &[f32],
278    config: &SanitizedDenoiseOptions,
279    quiet_rms: f32,
280) -> Vec<f32> {
281    let close_threshold = quiet_rms.max(1e-5) * 1.1;
282    let open_threshold = close_threshold * (2.4 + 0.25 * config.noise_reduction.max(0.0));
283    let floor = (config.residual_floor * 0.5).clamp(0.05, 0.35);
284    let attack_coeff = envelope_coeff(config.sample_rate, 0.006);
285    let release_coeff = envelope_coeff(config.sample_rate, 0.08);
286    let mut envelope = 0.0f32;
287
288    denoised
289        .iter()
290        .zip(sidechain.iter())
291        .map(|(&sample, &driver)| {
292            let amplitude = sample.abs().max(driver.abs());
293            envelope = if amplitude > envelope {
294                attack_coeff * envelope + (1.0 - attack_coeff) * amplitude
295            } else {
296                release_coeff * envelope + (1.0 - release_coeff) * amplitude
297            };
298
299            let normalized = ((envelope - close_threshold)
300                / (open_threshold - close_threshold + 1e-6))
301                .clamp(0.0, 1.0);
302            let gain = floor + (1.0 - floor) * smoothstep(normalized);
303            (sample * gain).clamp(-1.0, 1.0)
304        })
305        .collect()
306}
307
/// One-pole smoothing coefficient `exp(-1 / (sample_rate * seconds))`.
///
/// `seconds` is the time constant and is floored at 1 ms. Returns 0.0
/// (no smoothing) when `sample_rate` is zero.
fn envelope_coeff(sample_rate: u32, seconds: f32) -> f32 {
    match sample_rate {
        0 => 0.0,
        rate => {
            let time_constant_samples = rate as f32 * seconds.max(1e-3);
            (-1.0 / time_constant_samples).exp()
        }
    }
}
315
/// Cubic ease curve `3v² - 2v³`.
///
/// Maps 0 to 0 and 1 to 1 with zero slope at both ends. Callers pass values
/// already clamped to `[0, 1]`.
fn smoothstep(value: f32) -> f32 {
    let squared = value * value;
    squared * (3.0 - 2.0 * value)
}
319
320fn normalize_and_smooth(values: &mut [f32], divisor: f32) {
321    for value in values.iter_mut() {
322        *value /= divisor.max(1.0);
323    }
324    smooth_in_place(values);
325}
326
327fn apply_speech_bandpass(samples: &[f32], config: &SanitizedDenoiseOptions) -> Vec<f32> {
328    let mut output = samples.to_vec();
329    if let Some(mut high_pass) = Biquad::high_pass(config.sample_rate, config.low_hz, 0.707) {
330        for sample in &mut output {
331            *sample = high_pass.process(*sample);
332        }
333    }
334    if let Some(mut low_pass) = Biquad::low_pass(config.sample_rate, config.high_hz, 0.707) {
335        for sample in &mut output {
336            *sample = low_pass.process(*sample);
337        }
338    }
339    output
340}
341
/// Periodic Hann window of length `frame_size`: `w[n] = 0.5 - 0.5 * cos(2πn / N)`.
///
/// Sizes 0 and 1 both yield the single-element window `[1.0]`.
fn hann_window(frame_size: usize) -> Vec<f32> {
    match frame_size {
        0 | 1 => vec![1.0],
        size => {
            let length = size as f32;
            let mut window = Vec::with_capacity(size);
            for n in 0..size {
                let phase = 2.0 * PI * n as f32 / length;
                window.push(0.5 - 0.5 * phase.cos());
            }
            window
        }
    }
}
351
/// Start offsets of the analysis frames covering `sample_count` samples.
///
/// Frames advance by `hop_size` until one reaches the end of the signal. A
/// signal no longer than one frame yields the single offset `0`. A `hop_size`
/// of 0 is treated as 1, so the loop always terminates.
fn frame_offsets(sample_count: usize, frame_size: usize, hop_size: usize) -> Vec<usize> {
    if sample_count <= frame_size {
        return vec![0];
    }

    // A zero hop would never advance `start` and would spin forever.
    let hop = hop_size.max(1);
    let mut offsets = Vec::with_capacity((sample_count - frame_size) / hop + 2);
    let mut start = 0usize;
    while start < sample_count {
        offsets.push(start);
        if start.saturating_add(frame_size) >= sample_count {
            break;
        }
        start = start.saturating_add(hop);
    }
    offsets
}
368
/// Root-mean-square level of the `frame_size` samples starting at `start`.
///
/// Positions past the end of `samples` count as silence (zero). An empty
/// frame (`frame_size == 0`) has level 0.0 instead of the NaN that `0.0 / 0.0`
/// would produce.
fn frame_rms(samples: &[f32], start: usize, frame_size: usize) -> f32 {
    if frame_size == 0 {
        return 0.0;
    }

    // Missing trailing samples would only add 0.0, so summing just the samples
    // that exist gives the same total as explicit zero padding. It also avoids
    // computing `start + index`.
    let energy = samples
        .iter()
        .skip(start)
        .take(frame_size)
        .fold(0.0f32, |acc, &sample| acc + sample * sample);
    (energy / frame_size as f32).sqrt()
}
377
378fn load_windowed_frame(samples: &[f32], start: usize, window: &[f32], buffer: &mut [Complex32]) {
379    for (index, value) in buffer.iter_mut().enumerate() {
380        let sample = samples.get(start + index).copied().unwrap_or(0.0);
381        *value = Complex32::new(sample * window[index], 0.0);
382    }
383}
384
/// Replaces each value with the mean of itself and its immediate neighbours.
///
/// The two end points average only the neighbours that exist. Slices shorter
/// than three elements are left unchanged.
fn smooth_in_place(values: &mut [f32]) {
    let len = values.len();
    if len < 3 {
        return;
    }

    let snapshot = values.to_vec();
    for (index, slot) in values.iter_mut().enumerate() {
        let lo = index.saturating_sub(1);
        let hi = (index + 1).min(len - 1);
        let neighbourhood = &snapshot[lo..=hi];
        *slot = neighbourhood.iter().copied().sum::<f32>() / neighbourhood.len() as f32;
    }
}
398
/// Baseline noise statistics gathered from the quietest frames.
#[derive(Debug, Clone)]
struct NoiseEstimate {
    /// Mean FFT magnitude per bin over the quiet frames, smoothed across bins.
    spectrum: Vec<f32>,
    /// Mean RMS level of the quiet frames.
    quiet_rms: f32,
}
404
405#[derive(Debug, Clone, Copy)]
406struct SanitizedDenoiseOptions {
407    sample_rate: u32,
408    frame_size: usize,
409    hop_size: usize,
410    low_hz: f32,
411    high_hz: f32,
412    noise_estimation_percentile: f32,
413    noise_reduction: f32,
414    residual_floor: f32,
415    wet_mix: f32,
416}
417
418impl SanitizedDenoiseOptions {
419    fn new(sample_rate: u32, options: DenoiseOptions) -> Self {
420        let frame_size = options.frame_size.max(128);
421        let hop_size = options.hop_size.max(1).min(frame_size);
422        let nyquist = sample_rate as f32 * 0.5;
423
424        Self {
425            sample_rate,
426            frame_size,
427            hop_size,
428            low_hz: options.speech_low_hz.max(0.0).min(nyquist * 0.9),
429            high_hz: options
430                .speech_high_hz
431                .max(options.speech_low_hz + 1.0)
432                .min(nyquist * 0.98),
433            noise_estimation_percentile: options.noise_estimation_percentile.clamp(0.05, 0.8),
434            noise_reduction: options.noise_reduction.max(0.0),
435            residual_floor: options.residual_floor.clamp(0.0, 1.0),
436            wet_mix: options.wet_mix.clamp(0.0, 1.0),
437        }
438    }
439}
440
/// Second-order IIR filter section using RBJ Audio EQ Cookbook coefficients.
///
/// Coefficients are normalized so that `a0 = 1`. Samples are processed in
/// transposed direct form II, with `z1`/`z2` as the state.
#[derive(Debug, Clone, Copy)]
struct Biquad {
    b0: f32,
    b1: f32,
    b2: f32,
    a1: f32,
    a2: f32,
    z1: f32,
    z2: f32,
}

impl Biquad {
    /// High-pass section at `cutoff_hz` with quality factor `q`.
    ///
    /// Returns `None` if the cutoff is not strictly inside `(0, nyquist)` or
    /// `sample_rate` is zero.
    fn high_pass(sample_rate: u32, cutoff_hz: f32, q: f32) -> Option<Self> {
        Self::from_coefficients(sample_rate, cutoff_hz, q, FilterKind::HighPass)
    }

    /// Low-pass section at `cutoff_hz` with quality factor `q`.
    ///
    /// Returns `None` if the cutoff is not strictly inside `(0, nyquist)` or
    /// `sample_rate` is zero.
    fn low_pass(sample_rate: u32, cutoff_hz: f32, q: f32) -> Option<Self> {
        Self::from_coefficients(sample_rate, cutoff_hz, q, FilterKind::LowPass)
    }

    /// Designs a normalized section of the given `kind`.
    ///
    /// A NaN cutoff is rejected, like any other cutoff outside `(0, nyquist)`.
    /// `q` is floored at 1e-3.
    fn from_coefficients(
        sample_rate: u32,
        cutoff_hz: f32,
        q: f32,
        kind: FilterKind,
    ) -> Option<Self> {
        let nyquist = sample_rate as f32 * 0.5;
        // Written as a positive range test: NaN fails every comparison, so it is
        // rejected here instead of slipping through `<=`/`>=` checks and
        // producing NaN coefficients.
        if sample_rate == 0 || !(cutoff_hz > 0.0 && cutoff_hz < nyquist) {
            return None;
        }

        let omega = 2.0 * PI * cutoff_hz / sample_rate as f32;
        let sin_omega = omega.sin();
        let cos_omega = omega.cos();
        let alpha = sin_omega / (2.0 * q.max(1e-3));
        let (b0, b1, b2) = match kind {
            FilterKind::LowPass => (
                (1.0 - cos_omega) * 0.5,
                1.0 - cos_omega,
                (1.0 - cos_omega) * 0.5,
            ),
            FilterKind::HighPass => (
                (1.0 + cos_omega) * 0.5,
                -(1.0 + cos_omega),
                (1.0 + cos_omega) * 0.5,
            ),
        };
        let a0 = 1.0 + alpha;

        Some(Self {
            b0: b0 / a0,
            b1: b1 / a0,
            b2: b2 / a0,
            a1: (-2.0 * cos_omega) / a0,
            a2: (1.0 - alpha) / a0,
            z1: 0.0,
            z2: 0.0,
        })
    }

    /// Filters one sample and updates the internal state.
    fn process(&mut self, input: f32) -> f32 {
        let output = input * self.b0 + self.z1;
        self.z1 = input * self.b1 + self.z2 - self.a1 * output;
        self.z2 = input * self.b2 - self.a2 * output;
        output
    }
}

/// Response shape designed by `Biquad::from_coefficients`.
#[derive(Debug, Clone, Copy)]
enum FilterKind {
    LowPass,
    HighPass,
}