1use std::cmp::Ordering;
2use std::f32::consts::PI;
3
4use rustfft::num_complex::Complex32;
5use rustfft::{Fft, FftPlanner};
6
7use super::AudioSamples;
8
/// Tuning parameters for the speech-oriented spectral denoiser.
///
/// Values are sanitized before processing, so out-of-range settings are clamped rather than
/// rejected.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct DenoiseOptions {
    /// FFT frame length in samples; values below 128 are raised to 128.
    pub frame_size: usize,
    /// Step in samples between consecutive analysis frames; clamped to `1..=frame_size`.
    pub hop_size: usize,
    /// Cutoff of the high-pass stage applied before spectral processing, in Hz.
    ///
    /// Clamped to `[0, 0.9 * nyquist]`; a cutoff of `0` skips the high-pass stage.
    pub speech_low_hz: f32,
    /// Cutoff of the low-pass stage applied before spectral processing, in Hz.
    ///
    /// Raised to at least `speech_low_hz + 1.0` and capped at `0.98 * nyquist`.
    pub speech_high_hz: f32,
    /// Fraction (0–1, not a percentage) of the lowest-RMS frames that are averaged into the
    /// initial noise profile; clamped to `[0.05, 0.8]`.
    pub noise_estimation_percentile: f32,
    /// Over-subtraction factor applied to the estimated noise power in the spectral gain;
    /// larger values remove more noise. Negative values are treated as `0`.
    pub noise_reduction: f32,
    /// Minimum spectral gain, clamped to `[0, 1]`, so no frequency bin is muted completely.
    pub residual_floor: f32,
    /// Blend between the spectrally cleaned signal (`1.0`) and the band-passed input (`0.0`);
    /// clamped to `[0, 1]`.
    pub wet_mix: f32,
}
34
35impl Default for DenoiseOptions {
36 fn default() -> Self {
37 Self {
38 frame_size: 1024,
39 hop_size: 256,
40 speech_low_hz: 110.0,
41 speech_high_hz: 5_800.0,
42 noise_estimation_percentile: 0.2,
43 noise_reduction: 1.35,
44 residual_floor: 0.08,
45 wet_mix: 0.9,
46 }
47 }
48}
49
50pub(super) fn denoise_audio_samples(audio: &AudioSamples, options: DenoiseOptions) -> AudioSamples {
51 if audio.is_empty() || audio.sample_rate == 0 {
52 return audio.clone();
53 }
54
55 let config = SanitizedDenoiseOptions::new(audio.sample_rate, options);
56 let filtered = apply_speech_bandpass(&audio.samples, &config);
57 let window = hann_window(config.frame_size);
58 let mut planner = FftPlanner::<f32>::new();
59 let fft = planner.plan_fft_forward(config.frame_size);
60 let ifft = planner.plan_fft_inverse(config.frame_size);
61 let noise_estimate = estimate_noise_profile(&filtered, &config, &window, fft.as_ref());
62 let cleaned = render_denoised_samples(
63 &filtered,
64 &config,
65 &window,
66 fft.as_ref(),
67 ifft.as_ref(),
68 &noise_estimate,
69 );
70
71 AudioSamples::new(cleaned, audio.sample_rate)
72}
73
74fn estimate_noise_profile(
75 samples: &[f32],
76 config: &SanitizedDenoiseOptions,
77 window: &[f32],
78 fft: &dyn Fft<f32>,
79) -> NoiseEstimate {
80 let selected_offsets = select_quiet_frame_offsets(samples, config);
81 let mut profile = vec![0.0; config.frame_size];
82 let mut buffer = vec![Complex32::default(); config.frame_size];
83 let mut quiet_rms_sum = 0.0;
84
85 for &start in &selected_offsets {
86 quiet_rms_sum += frame_rms(samples, start, config.frame_size);
87 load_windowed_frame(samples, start, window, &mut buffer);
88 fft.process(&mut buffer);
89 for (value, spectrum) in profile.iter_mut().zip(&buffer) {
90 *value += spectrum.norm();
91 }
92 }
93
94 let frame_count = selected_offsets.len().max(1) as f32;
95 normalize_and_smooth(&mut profile, frame_count);
96 NoiseEstimate {
97 spectrum: profile,
98 quiet_rms: quiet_rms_sum / frame_count,
99 }
100}
101
102fn select_quiet_frame_offsets(samples: &[f32], config: &SanitizedDenoiseOptions) -> Vec<usize> {
103 let mut ranked: Vec<(usize, f32)> =
104 frame_offsets(samples.len(), config.frame_size, config.hop_size)
105 .into_iter()
106 .map(|start| (start, frame_rms(samples, start, config.frame_size)))
107 .collect();
108 ranked.sort_by(|left, right| left.1.partial_cmp(&right.1).unwrap_or(Ordering::Equal));
109
110 ranked
111 .into_iter()
112 .take(quiet_frame_count(samples, config))
113 .map(|(start, _)| start)
114 .collect()
115}
116
117fn quiet_frame_count(samples: &[f32], config: &SanitizedDenoiseOptions) -> usize {
118 let total_frames = frame_offsets(samples.len(), config.frame_size, config.hop_size)
119 .len()
120 .max(1);
121 ((total_frames as f32 * config.noise_estimation_percentile).ceil() as usize)
122 .clamp(1, total_frames)
123}
124
125fn render_denoised_samples(
126 samples: &[f32],
127 config: &SanitizedDenoiseOptions,
128 window: &[f32],
129 fft: &dyn Fft<f32>,
130 ifft: &dyn Fft<f32>,
131 noise_estimate: &NoiseEstimate,
132) -> Vec<f32> {
133 let offsets = frame_offsets(samples.len(), config.frame_size, config.hop_size);
134 let mut overlap_add = vec![0.0; samples.len() + config.frame_size];
135 let mut normalization = vec![0.0; samples.len() + config.frame_size];
136 let mut buffer = vec![Complex32::default(); config.frame_size];
137 let mut mask = vec![0.0; config.frame_size];
138 let mut adaptive_noise = noise_estimate.spectrum.clone();
139 let mut previous_mask = vec![1.0; config.frame_size];
140
141 for start in offsets {
142 load_windowed_frame(samples, start, window, &mut buffer);
143 fft.process(&mut buffer);
144 update_adaptive_noise_profile(
145 &buffer,
146 &noise_estimate.spectrum,
147 &mut adaptive_noise,
148 frame_rms(samples, start, config.frame_size),
149 noise_estimate.quiet_rms,
150 config,
151 );
152 build_spectral_mask(&buffer, &adaptive_noise, &previous_mask, config, &mut mask);
153 previous_mask.clone_from_slice(&mask);
154 apply_mask(&mut buffer, &mask);
155 ifft.process(&mut buffer);
156 overlap_add_frame(&buffer, start, window, &mut overlap_add, &mut normalization);
157 }
158
159 finalize_samples(
160 samples,
161 &overlap_add,
162 &normalization,
163 config,
164 noise_estimate.quiet_rms,
165 )
166}
167
168fn build_spectral_mask(
169 spectrum: &[Complex32],
170 noise_profile: &[f32],
171 previous_mask: &[f32],
172 config: &SanitizedDenoiseOptions,
173 mask: &mut [f32],
174) {
175 const EPSILON: f32 = 1e-6;
176
177 for index in 0..mask.len() {
178 let magnitude = spectrum[index].norm();
179 let noise = noise_profile[index].max(EPSILON);
180 let power = magnitude * magnitude;
181 let noise_power = noise * noise;
182 let wiener =
183 (1.0 - config.noise_reduction * noise_power / (power + EPSILON)).clamp(0.0, 1.0);
184 let snr_db = 10.0 * ((power + EPSILON) / (noise_power + EPSILON)).log10();
185 let soft_gate = ((snr_db + 3.0) / 12.0).clamp(0.0, 1.0);
186 let harmonic_ratio = magnitude / noise;
187 let harmonic_guard = ((harmonic_ratio - 0.75) / 2.25).clamp(0.0, 1.0);
188 let gain = wiener * (0.2 + 0.8 * soft_gate) * (0.25 + 0.75 * harmonic_guard);
189 let temporal = if gain > previous_mask[index] {
190 previous_mask[index] * 0.25 + gain * 0.75
191 } else {
192 previous_mask[index] * 0.4 + gain * 0.6
193 };
194 mask[index] = config.residual_floor + (1.0 - config.residual_floor) * temporal;
195 }
196 smooth_in_place(mask);
197 smooth_in_place(mask);
198}
199
200fn update_adaptive_noise_profile(
201 spectrum: &[Complex32],
202 base_noise: &[f32],
203 adaptive_noise: &mut [f32],
204 frame_rms: f32,
205 quiet_rms: f32,
206 config: &SanitizedDenoiseOptions,
207) {
208 let quiet_threshold = quiet_rms.max(1e-5) * (1.4 + config.noise_estimation_percentile);
209 let update_rate = if frame_rms <= quiet_threshold {
210 0.18
211 } else {
212 0.035
213 };
214 let growth_limit = if frame_rms <= quiet_threshold {
215 2.0
216 } else {
217 1.15
218 };
219
220 for index in 0..adaptive_noise.len() {
221 let capped_magnitude = spectrum[index]
222 .norm()
223 .min(base_noise[index] * growth_limit + 1e-5);
224 adaptive_noise[index] =
225 adaptive_noise[index] * (1.0 - update_rate) + capped_magnitude * update_rate;
226 adaptive_noise[index] = adaptive_noise[index].max(base_noise[index] * 0.5);
227 }
228
229 smooth_in_place(adaptive_noise);
230}
231
232fn apply_mask(spectrum: &mut [Complex32], mask: &[f32]) {
233 for (bin, value) in spectrum.iter_mut().zip(mask) {
234 *bin *= *value;
235 }
236}
237
238fn overlap_add_frame(
239 frame: &[Complex32],
240 start: usize,
241 window: &[f32],
242 overlap_add: &mut [f32],
243 normalization: &mut [f32],
244) {
245 let frame_size = window.len() as f32;
246 for index in 0..window.len() {
247 let sample = frame[index].re / frame_size;
248 let windowed = sample * window[index];
249 overlap_add[start + index] += windowed;
250 normalization[start + index] += window[index] * window[index];
251 }
252}
253
254fn finalize_samples(
255 filtered: &[f32],
256 overlap_add: &[f32],
257 normalization: &[f32],
258 config: &SanitizedDenoiseOptions,
259 quiet_rms: f32,
260) -> Vec<f32> {
261 let blended = (0..filtered.len())
262 .map(|index| {
263 let restored = if normalization[index] > 1e-6 {
264 overlap_add[index] / normalization[index]
265 } else {
266 0.0
267 };
268 (restored * config.wet_mix + filtered[index] * (1.0 - config.wet_mix)).clamp(-1.0, 1.0)
269 })
270 .collect::<Vec<_>>();
271
272 apply_adaptive_noise_gate(&blended, filtered, config, quiet_rms)
273}
274
275fn apply_adaptive_noise_gate(
276 denoised: &[f32],
277 sidechain: &[f32],
278 config: &SanitizedDenoiseOptions,
279 quiet_rms: f32,
280) -> Vec<f32> {
281 let close_threshold = quiet_rms.max(1e-5) * 1.1;
282 let open_threshold = close_threshold * (2.4 + 0.25 * config.noise_reduction.max(0.0));
283 let floor = (config.residual_floor * 0.5).clamp(0.05, 0.35);
284 let attack_coeff = envelope_coeff(config.sample_rate, 0.006);
285 let release_coeff = envelope_coeff(config.sample_rate, 0.08);
286 let mut envelope = 0.0f32;
287
288 denoised
289 .iter()
290 .zip(sidechain.iter())
291 .map(|(&sample, &driver)| {
292 let amplitude = sample.abs().max(driver.abs());
293 envelope = if amplitude > envelope {
294 attack_coeff * envelope + (1.0 - attack_coeff) * amplitude
295 } else {
296 release_coeff * envelope + (1.0 - release_coeff) * amplitude
297 };
298
299 let normalized = ((envelope - close_threshold)
300 / (open_threshold - close_threshold + 1e-6))
301 .clamp(0.0, 1.0);
302 let gain = floor + (1.0 - floor) * smoothstep(normalized);
303 (sample * gain).clamp(-1.0, 1.0)
304 })
305 .collect()
306}
307
/// One-pole smoothing coefficient `exp(-1 / (sample_rate * seconds))`.
///
/// `seconds` is the time constant and is floored at 1 ms. Returns 0 (no smoothing) when
/// `sample_rate` is 0.
fn envelope_coeff(sample_rate: u32, seconds: f32) -> f32 {
    match sample_rate {
        0 => 0.0,
        rate => {
            let time_constant_in_samples = rate as f32 * seconds.max(1e-3);
            (-1.0 / time_constant_in_samples).exp()
        }
    }
}
315
/// Cubic ease curve `3v² - 2v³`.
///
/// Maps 0 to 0 and 1 to 1, with zero slope at both ends. The caller in this file passes values
/// already clamped to `[0, 1]`.
fn smoothstep(value: f32) -> f32 {
    let squared = value * value;
    squared * (3.0 - 2.0 * value)
}
319
320fn normalize_and_smooth(values: &mut [f32], divisor: f32) {
321 for value in values.iter_mut() {
322 *value /= divisor.max(1.0);
323 }
324 smooth_in_place(values);
325}
326
327fn apply_speech_bandpass(samples: &[f32], config: &SanitizedDenoiseOptions) -> Vec<f32> {
328 let mut output = samples.to_vec();
329 if let Some(mut high_pass) = Biquad::high_pass(config.sample_rate, config.low_hz, 0.707) {
330 for sample in &mut output {
331 *sample = high_pass.process(*sample);
332 }
333 }
334 if let Some(mut low_pass) = Biquad::low_pass(config.sample_rate, config.high_hz, 0.707) {
335 for sample in &mut output {
336 *sample = low_pass.process(*sample);
337 }
338 }
339 output
340}
341
/// Periodic Hann window of length `frame_size`.
///
/// The phase is divided by `N` rather than `N - 1`. Sizes 0 and 1 both yield `[1.0]`.
fn hann_window(frame_size: usize) -> Vec<f32> {
    match frame_size {
        0 | 1 => vec![1.0],
        length => {
            let denominator = length as f32;
            let mut window = Vec::with_capacity(length);
            for index in 0..length {
                let phase = 2.0 * PI * index as f32 / denominator;
                window.push(0.5 - 0.5 * phase.cos());
            }
            window
        }
    }
}
351
/// Start offsets of the analysis frames for a signal of `sample_count` samples.
///
/// Frames start every `hop_size` samples until one reaches or passes the end of the signal.
/// When `hop_size <= frame_size` (which the sanitized config guarantees), every sample is
/// covered. A signal no longer than one frame yields the single offset `0`.
///
/// A `hop_size` of 0 is treated as 1, because a zero hop would never advance.
fn frame_offsets(sample_count: usize, frame_size: usize, hop_size: usize) -> Vec<usize> {
    if sample_count <= frame_size {
        return vec![0];
    }

    // Without this guard a zero hop leaves `start` at 0 forever and `offsets` grows until
    // memory runs out.
    let hop = hop_size.max(1);
    // Upper bound on the frame count: ceil((sample_count - frame_size) / hop) + 1.
    let mut offsets = Vec::with_capacity((sample_count - frame_size) / hop + 2);
    let mut start = 0usize;
    while start < sample_count {
        offsets.push(start);
        if start + frame_size >= sample_count {
            break;
        }
        start = start.saturating_add(hop);
    }
    offsets
}
368
/// Root-mean-square level of the `frame_size` samples starting at `start`.
///
/// Positions past the end of `samples` count as silence (zero padding), so the divisor is
/// always `frame_size`. A `frame_size` of 0 yields NaN.
fn frame_rms(samples: &[f32], start: usize, frame_size: usize) -> f32 {
    let energy = (start..start + frame_size)
        .map(|position| samples.get(position).copied().unwrap_or(0.0))
        .fold(0.0f32, |total, sample| total + sample * sample);
    (energy / frame_size as f32).sqrt()
}
377
378fn load_windowed_frame(samples: &[f32], start: usize, window: &[f32], buffer: &mut [Complex32]) {
379 for (index, value) in buffer.iter_mut().enumerate() {
380 let sample = samples.get(start + index).copied().unwrap_or(0.0);
381 *value = Complex32::new(sample * window[index], 0.0);
382 }
383}
384
/// Three-tap moving average across neighbouring entries, computed in place.
///
/// Interior entries become the mean of themselves and both neighbours. The first and last
/// entries average with their single neighbour. Slices shorter than 3 are left unchanged.
///
/// Runs without heap allocation. This matters because it is called several times per frame
/// in the denoising loop. The only extra state needed is the unsmoothed value of the previous
/// entry.
fn smooth_in_place(values: &mut [f32]) {
    let len = values.len();
    if len < 3 {
        return;
    }

    // Entry `index - 1` has already been overwritten when `index` is processed, so carry its
    // original value forward in `previous`.
    let mut previous = values[0];
    values[0] = (values[0] + values[1]) / 2.0;
    for index in 1..len - 1 {
        let current = values[index];
        values[index] = (previous + current + values[index + 1]) / 3.0;
        previous = current;
    }
    let last = len - 1;
    values[last] = (previous + values[last]) / 2.0;
}
398
/// Noise reference measured from the lowest-RMS frames of the band-passed input.
#[derive(Debug, Clone)]
struct NoiseEstimate {
    /// Mean FFT magnitude per bin over the selected frames (one entry per FFT bin), smoothed
    /// across neighbouring bins.
    spectrum: Vec<f32>,
    /// Mean RMS level of the frames used to build `spectrum`.
    quiet_rms: f32,
}
404
/// `DenoiseOptions` after clamping into processable ranges, plus the input sample rate.
#[derive(Debug, Clone, Copy)]
struct SanitizedDenoiseOptions {
    /// Sample rate of the audio being processed, in Hz.
    sample_rate: u32,
    /// FFT frame length in samples (at least 128).
    frame_size: usize,
    /// Frame step in samples, clamped to `1..=frame_size`.
    hop_size: usize,
    /// High-pass cutoff in Hz.
    low_hz: f32,
    /// Low-pass cutoff in Hz.
    high_hz: f32,
    /// Fraction of lowest-RMS frames used for the noise profile, clamped to `[0.05, 0.8]`.
    noise_estimation_percentile: f32,
    /// Noise over-subtraction factor (negative inputs become 0).
    noise_reduction: f32,
    /// Minimum spectral gain, clamped to `[0, 1]`.
    residual_floor: f32,
    /// Wet/dry blend, clamped to `[0, 1]`.
    wet_mix: f32,
}
417
418impl SanitizedDenoiseOptions {
419 fn new(sample_rate: u32, options: DenoiseOptions) -> Self {
420 let frame_size = options.frame_size.max(128);
421 let hop_size = options.hop_size.max(1).min(frame_size);
422 let nyquist = sample_rate as f32 * 0.5;
423
424 Self {
425 sample_rate,
426 frame_size,
427 hop_size,
428 low_hz: options.speech_low_hz.max(0.0).min(nyquist * 0.9),
429 high_hz: options
430 .speech_high_hz
431 .max(options.speech_low_hz + 1.0)
432 .min(nyquist * 0.98),
433 noise_estimation_percentile: options.noise_estimation_percentile.clamp(0.05, 0.8),
434 noise_reduction: options.noise_reduction.max(0.0),
435 residual_floor: options.residual_floor.clamp(0.0, 1.0),
436 wet_mix: options.wet_mix.clamp(0.0, 1.0),
437 }
438 }
439}
440
/// Second-order IIR filter section evaluated in transposed direct form II (see `process`).
///
/// Coefficients are stored already divided by `a0`, so `a0` is implicitly 1.
#[derive(Debug, Clone, Copy)]
struct Biquad {
    /// Feed-forward coefficient applied to the current input.
    b0: f32,
    /// Feed-forward coefficient applied one sample back.
    b1: f32,
    /// Feed-forward coefficient applied two samples back.
    b2: f32,
    /// Feedback coefficient applied to the output one sample back.
    a1: f32,
    /// Feedback coefficient applied to the output two samples back.
    a2: f32,
    /// First state variable carried between `process` calls.
    z1: f32,
    /// Second state variable carried between `process` calls.
    z2: f32,
}
451
452impl Biquad {
453 fn high_pass(sample_rate: u32, cutoff_hz: f32, q: f32) -> Option<Self> {
454 Self::from_coefficients(sample_rate, cutoff_hz, q, FilterKind::HighPass)
455 }
456
457 fn low_pass(sample_rate: u32, cutoff_hz: f32, q: f32) -> Option<Self> {
458 Self::from_coefficients(sample_rate, cutoff_hz, q, FilterKind::LowPass)
459 }
460
461 fn from_coefficients(
462 sample_rate: u32,
463 cutoff_hz: f32,
464 q: f32,
465 kind: FilterKind,
466 ) -> Option<Self> {
467 if sample_rate == 0 || cutoff_hz <= 0.0 || cutoff_hz >= sample_rate as f32 * 0.5 {
468 return None;
469 }
470
471 let omega = 2.0 * PI * cutoff_hz / sample_rate as f32;
472 let sin_omega = omega.sin();
473 let cos_omega = omega.cos();
474 let alpha = sin_omega / (2.0 * q.max(1e-3));
475 let (b0, b1, b2) = match kind {
476 FilterKind::LowPass => (
477 (1.0 - cos_omega) * 0.5,
478 1.0 - cos_omega,
479 (1.0 - cos_omega) * 0.5,
480 ),
481 FilterKind::HighPass => (
482 (1.0 + cos_omega) * 0.5,
483 -(1.0 + cos_omega),
484 (1.0 + cos_omega) * 0.5,
485 ),
486 };
487 let a0 = 1.0 + alpha;
488
489 Some(Self {
490 b0: b0 / a0,
491 b1: b1 / a0,
492 b2: b2 / a0,
493 a1: (-2.0 * cos_omega) / a0,
494 a2: (1.0 - alpha) / a0,
495 z1: 0.0,
496 z2: 0.0,
497 })
498 }
499
500 fn process(&mut self, input: f32) -> f32 {
501 let output = input * self.b0 + self.z1;
502 self.z1 = input * self.b1 + self.z2 - self.a1 * output;
503 self.z2 = input * self.b2 - self.a2 * output;
504 output
505 }
506}
507
/// Response shape selected when computing `Biquad` coefficients.
#[derive(Debug, Clone, Copy)]
enum FilterKind {
    /// Low-pass response: passes content below the cutoff.
    LowPass,
    /// High-pass response: passes content above the cutoff.
    HighPass,
}