// audiobook-creation-exchange 0.1.0
//
// ACX-compliant audio post-processing: normalisation, limiting, gating, LUFS measurement, and spectral analysis for AI-generated speech audio.
// Documentation
//! Wiener spectral subtraction noise reducer.
//!
//! Estimates the noise power spectrum from the first `profile_ms` of audio
//! (typically the head room-tone bookend), then attenuates each FFT bin in
//! every subsequent frame by the Wiener gain:
//!
//! ```text
//! G[k] = max(floor, 1 − α · noise_power[k] / frame_power[k])
//! ```
//!
//! `α` (oversubtraction) is set above 1.0 to compensate for noise estimation
//! error; `floor` prevents total cancellation of low-energy bins (musical noise).
//!
//! The same OLA STFT structure as the de-esser and plosive suppressor is used
//! (50 ms frames, 50 % hop, periodic Hann window).

use rustfft::{FftPlanner, num_complex::Complex};

/// Oversubtraction factor (>1 compensates for noise estimation error).
///
/// Used by [`denoise`] as the `oversubtraction` argument of
/// [`denoise_with_params`].
pub const DEFAULT_OVERSUBTRACTION: f32 = 1.5;
/// Spectral floor — bins are never reduced below this gain (−20 dB ≈ 0.1).
///
/// Used by [`denoise`] as the `spectral_floor` argument of
/// [`denoise_with_params`].
pub const DEFAULT_SPECTRAL_FLOOR: f32 = 0.1;
/// Length of the noise profile window at the start of the signal, in ms.
///
/// Used by [`denoise`] as the `profile_ms` argument of
/// [`denoise_with_params`].
pub const DEFAULT_PROFILE_MS: u32 = 200;

// STFT frame length in milliseconds. `denoise_with_params` converts it to a
// sample count with the caller's sample rate (`sample_rate * WINDOW_MS / 1000`).
const WINDOW_MS: usize = 50;

/// Denoise `samples` in place using the module's default settings.
///
/// Convenience wrapper around [`denoise_with_params`] that supplies
/// [`DEFAULT_PROFILE_MS`], [`DEFAULT_OVERSUBTRACTION`] and
/// [`DEFAULT_SPECTRAL_FLOOR`].
pub fn denoise(samples: &mut [i16], sample_rate: u32) {
    let (profile_ms, alpha, floor) = (
        DEFAULT_PROFILE_MS,
        DEFAULT_OVERSUBTRACTION,
        DEFAULT_SPECTRAL_FLOOR,
    );
    denoise_with_params(samples, sample_rate, profile_ms, alpha, floor);
}

/// Reduce noise with explicit parameters.
///
/// The signal is processed in place:
///
/// 1. A noise power spectrum is averaged over every analysis frame that lies
///    entirely inside the first `profile_ms` milliseconds (at least one frame
///    is always used, zero-padded if the signal is shorter than a frame).
/// 2. Every frame of the signal is transformed, each bin is scaled by
///    `max(spectral_floor, 1 − oversubtraction · noise_power / frame_power)`,
///    transformed back and overlap-added.
///
/// # Parameters
///
/// * `samples` – mono PCM samples, overwritten with the denoised signal.
/// * `sample_rate` – sample rate in Hz; determines the frame length
///   (`WINDOW_MS` milliseconds).
/// * `profile_ms` – length of the leading noise-only segment, in ms.
/// * `oversubtraction` – multiplier applied to the noise estimate.
/// * `spectral_floor` – lower bound on the per-bin gain.
///
/// The call is a no-op when `samples` is empty or the frame would be shorter
/// than four samples.
pub fn denoise_with_params(
    samples: &mut [i16],
    sample_rate: u32,
    profile_ms: u32,
    oversubtraction: f32,
    spectral_floor: f32,
) {
    let window_size = (sample_rate as usize * WINDOW_MS) / 1000;
    if window_size < 4 || samples.is_empty() {
        return;
    }
    let hop = window_size / 2;
    let half = window_size / 2;
    let n = samples.len();

    // Computed in u64 so a long profile at a high sample rate cannot overflow
    // `usize` on 32-bit targets; the value is capped at `n` before narrowing.
    let profile_samples = ((u64::from(sample_rate) * u64::from(profile_ms) / 1000)
        .min(n as u64) as usize)
        .max(window_size);

    let hann = hann_periodic(window_size);

    let mut planner = FftPlanner::<f32>::new();
    let fft_forward = planner.plan_fft_forward(window_size);
    let fft_inverse = planner.plan_fft_inverse(window_size);

    // One scratch buffer reused by every frame in both passes.
    let mut buf = vec![Complex { re: 0.0f32, im: 0.0 }; window_size];

    // ── Step 1: build noise power profile from the leading segment ───────────
    // Only include frames whose windows are fully within the profile region —
    // boundary frames that overlap into the speech section would inflate the
    // noise estimate. Bins 0..=half are tracked so the Nyquist bin is profiled
    // too.
    let mut noise_power = vec![0f32; half + 1];
    let mut profile_count = 0usize;
    let mut start = 0usize;
    while start + window_size <= profile_samples {
        load_windowed_frame(&mut buf, samples, &hann, start, 0);
        fft_forward.process(&mut buf);
        for (acc, c) in noise_power.iter_mut().zip(&buf) {
            *acc += c.norm_sqr();
        }
        profile_count += 1;
        start += hop;
    }
    if profile_count > 0 {
        let inv = 1.0 / profile_count as f32;
        for p in noise_power.iter_mut() {
            *p *= inv;
        }
    }

    // ── Step 2: OLA processing with Wiener gain ───────────────────────────────
    // The frame grid begins one hop *before* the signal (that region is treated
    // as silence). Without this extra frame the first `hop` samples are covered
    // only by the rising edge of a single window, so the window-sum used for
    // normalisation is close to zero there and dividing by it turns the
    // filtered frame into full-scale clicks at the head of the audio.
    let n_frames = n.div_ceil(hop) + 1;
    let mut output = vec![0f32; n];
    let mut norm = vec![0f32; n];
    // rustfft does not normalise; forward + inverse scales by `window_size`.
    let scale = 1.0 / window_size as f32;

    for frame_idx in 0..n_frames {
        // Position in the signal as if it were prefixed by `hop` zeros.
        let pos = frame_idx * hop;
        load_windowed_frame(&mut buf, samples, &hann, pos, hop);

        fft_forward.process(&mut buf);

        // Apply the gain to each non-negative-frequency bin and its mirror.
        for k in 0..=half {
            let frame_p = buf[k].norm_sqr().max(f32::EPSILON);
            let gain = (1.0 - oversubtraction * noise_power[k] / frame_p).max(spectral_floor);
            buf[k].re *= gain;
            buf[k].im *= gain;
            // DC (mirror == window_size) and, for even sizes, the Nyquist bin
            // (mirror == k) have no distinct mirror bin.
            let mirror = window_size - k;
            if mirror < window_size && mirror > k {
                buf[mirror].re *= gain;
                buf[mirror].im *= gain;
            }
        }

        fft_inverse.process(&mut buf);

        for (i, (c, &w)) in buf.iter().zip(&hann).enumerate() {
            // Slots that fall in the leading padding have no output sample.
            let Some(out_idx) = (pos + i).checked_sub(hop) else {
                continue;
            };
            if out_idx >= n {
                break;
            }
            output[out_idx] += c.re * scale;
            norm[out_idx] += w;
        }
    }

    // Divide out the summed analysis windows; samples that received no window
    // weight at all are left untouched.
    for ((s, &acc), &w) in samples.iter_mut().zip(&output).zip(&norm) {
        if w > f32::EPSILON {
            *s = (acc / w)
                .round()
                .clamp(i16::MIN as f32, i16::MAX as f32) as i16;
        }
    }
}

/// Fill `buf` with one Hann-windowed frame of `samples`.
///
/// Slot `i` reads sample index `pos + i - lead`; indices before the start or
/// past the end of `samples` are treated as silence.
fn load_windowed_frame(
    buf: &mut [Complex<f32>],
    samples: &[i16],
    hann: &[f32],
    pos: usize,
    lead: usize,
) {
    for (i, (slot, &w)) in buf.iter_mut().zip(hann).enumerate() {
        let sample = (pos + i)
            .checked_sub(lead)
            .and_then(|idx| samples.get(idx))
            .map_or(0.0, |&s| f32::from(s));
        *slot = Complex {
            re: sample * w,
            im: 0.0,
        };
    }
}

/// Build a periodic (DFT-even) Hann window of length `n`.
///
/// Element `i` is `0.5 · (1 − cos(2π·i / n))`, so the window starts at zero and
/// does not repeat the zero at the end. Returns an empty vector for `n == 0`.
fn hann_periodic(n: usize) -> Vec<f32> {
    let len = n as f32;
    let mut window = Vec::with_capacity(n);
    for idx in 0..n {
        let phase = 2.0 * std::f32::consts::PI * idx as f32 / len;
        window.push(0.5 * (1.0 - phase.cos()));
    }
    window
}

#[cfg(test)]
mod tests {
    use rand::RngExt;

    use super::*;
    use crate::analyse::rms_db;

    /// Sample rate shared by every test signal, in Hz.
    const SR: u32 = 24_000;

    /// Sine wave of `secs` seconds at `freq_hz`, peak `amplitude`, clamped to i16.
    fn pure_tone(freq_hz: f32, amplitude: f32, secs: f32) -> Vec<i16> {
        let n = (SR as f32 * secs) as usize;
        let mut tone = Vec::with_capacity(n);
        for i in 0..n {
            let phase = 2.0 * std::f32::consts::PI * freq_hz * i as f32 / SR as f32;
            let v = amplitude * phase.sin();
            tone.push(v.clamp(i16::MIN as f32, i16::MAX as f32) as i16);
        }
        tone
    }

    /// `n` samples of uniform noise in `[-amplitude, amplitude)`, clamped to i16.
    fn white_noise(amplitude: f32, n: usize) -> Vec<i16> {
        let mut rng = rand::rng();
        let mut noise = Vec::with_capacity(n);
        for _ in 0..n {
            let v = (rng.random::<f32>() * 2.0 - 1.0) * amplitude;
            noise.push(v.clamp(i16::MIN as f32, i16::MAX as f32) as i16);
        }
        noise
    }

    #[test]
    fn empty_input_is_a_no_op() {
        // Must return without panicking on a zero-length buffer.
        let mut nothing: Vec<i16> = Vec::new();
        denoise(&mut nothing, SR);
    }

    #[test]
    fn speech_tone_preserved_after_denoising() {
        // Layout: 200 ms noise-only head (the profile), then 800 ms of tone.
        let profile_len = (SR as f32 * 0.2) as usize;
        let mut samples = white_noise(200.0, profile_len);
        let tone = pure_tone(440.0, 5_000.0, 0.8);
        samples.extend(tone);

        let before_rms = rms_db(&samples[profile_len..]);
        denoise(&mut samples, SR);
        let after_rms = rms_db(&samples[profile_len..]);

        // The tone section may lose at most 6 dB of level.
        let loss_db = before_rms - after_rms;
        assert!(
            loss_db < 6.0,
            "Speech attenuated too much: before={:.1} after={:.1}",
            before_rms,
            after_rms
        );
    }

    #[test]
    fn noise_is_attenuated_in_steady_state() {
        // Two seconds of white noise; the whole buffer doubles as the noise
        // profile (profile_ms = 2 000).
        let n = (SR as usize) * 2;
        let original = white_noise(500.0, n);
        let mut processed = original.clone();

        denoise_with_params(
            &mut processed,
            SR,
            2_000,
            DEFAULT_OVERSUBTRACTION,
            DEFAULT_SPECTRAL_FLOOR,
        );

        // Compare levels over 750 ms … 1 250 ms, far from the buffer edges where
        // overlap-add boundary effects could skew the measurement.
        let mid_s = (SR as usize * 750) / 1000;
        let mid_e = (SR as usize * 1250) / 1000;
        let centre = mid_s..mid_e;
        let before = rms_db(&original[centre.clone()]);
        let after = rms_db(&processed[centre]);

        assert!(
            after < before,
            "Noise not attenuated in steady state: before={:.1} after={:.1}",
            before,
            after
        );
    }
}