neco-minphase 0.1.0

Minimum-phase spectrum and overlap-add convolution kernels
Documentation
use neco_complex::Complex;
use neco_stft::{DspFloat, FftError};

/// Errors returned by the minimum-phase and convolution routines in this file.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MinPhaseError {
    /// The supplied gain curve did not have the required number of bins.
    ///
    /// `expected` is the length that was required (`fft_size / 2 + 1` in
    /// `compute_min_phase_spectrum`) and `got` is the length actually passed in.
    InvalidGainCurveLen { expected: usize, got: usize },
    /// An error produced by the FFT backend, wrapped unchanged.
    Fft(FftError),
}

/// Human-readable rendering of [`MinPhaseError`].
///
/// Length mismatches get a dedicated message; backend errors are rendered by
/// delegating to the wrapped value's own formatter.
impl core::fmt::Display for MinPhaseError {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        match self {
            // Delegate so the backend's wording is shown as-is.
            Self::Fft(inner) => inner.fmt(f),
            Self::InvalidGainCurveLen { expected, got } => f.write_fmt(format_args!(
                "wrong gain curve length: expected {expected}, got {got}"
            )),
        }
    }
}

// Marker impl: relies entirely on the default `std::error::Error` methods; the
// user-facing message comes from the `Display` impl.
impl std::error::Error for MinPhaseError {}

impl From<FftError> for MinPhaseError {
    fn from(value: FftError) -> Self {
        Self::Fft(value)
    }
}

/// Builds a one-sided minimum-phase spectrum whose magnitude follows `gain_curve`.
///
/// The computation is the cepstral (homomorphic) method:
///
/// 1. take the natural log of each gain (floored at `1e-20` so the log stays finite),
/// 2. inverse-transform the log magnitudes to obtain the real cepstrum,
/// 3. fold the cepstrum onto its causal half (keep index 0, double the
///    positive-quefrency terms, zero the negative-quefrency terms),
/// 4. forward-transform the folded cepstrum and exponentiate, turning the real
///    part into magnitude and the imaginary part into phase.
///
/// # Arguments
///
/// * `gain_curve` - linear (not dB) magnitudes, one per bin; must contain
///   exactly `fft_size / 2 + 1` values.
/// * `fft_size` - length of the real FFT used for both transforms. Even and
///   odd sizes are both handled by the folding step.
///
/// # Returns
///
/// The complex spectrum written into the forward plan's output buffer
/// (presumably `fft_size / 2 + 1` bins - confirm against the FFT backend).
///
/// # Errors
///
/// * [`MinPhaseError::InvalidGainCurveLen`] if `gain_curve.len() != fft_size / 2 + 1`.
/// * [`MinPhaseError::Fft`] if either FFT call reports an error.
///
/// # Panics
///
/// NOTE(review): `fft_size == 0` is not guarded; with a one-element gain curve
/// it would reach `cepstrum_min[0]` on an empty buffer (or fail earlier inside
/// the planner).
pub fn compute_min_phase_spectrum<T: DspFloat>(
    gain_curve: &[T],
    fft_size: usize,
) -> Result<Vec<Complex<T>>, MinPhaseError> {
    let num_bins = fft_size / 2 + 1;
    if gain_curve.len() != num_bins {
        return Err(MinPhaseError::InvalidGainCurveLen {
            expected: num_bins,
            got: gain_curve.len(),
        });
    }

    let epsilon = T::from_f64(1e-20);
    let two = T::from_f64(2.0);

    T::with_fft_planner(|planner| {
        let fft_fwd = planner.plan_fft_forward(fft_size);
        let fft_inv = planner.plan_fft_inverse(fft_size);
        let scale = T::one() / T::from_usize(fft_size);

        // Log-magnitude spectrum with zero phase; the floor avoids ln(0).
        let mut log_spectrum: Vec<Complex<T>> = gain_curve
            .iter()
            .map(|&gain| Complex::new(gain.max(epsilon).ln(), T::zero()))
            .collect();

        // Real cepstrum; the scale is applied here, after the inverse transform.
        let mut cepstrum = fft_inv.make_output_vec();
        fft_inv.process(&mut log_spectrum, &mut cepstrum)?;
        for value in &mut cepstrum {
            *value *= scale;
        }

        // Fold onto the causal half. For even sizes the Nyquist sample at
        // `half` is its own mirror image and is kept once. For odd sizes there
        // is no Nyquist sample, so index `half` is an ordinary
        // positive-quefrency term and must be doubled like the others.
        let half = fft_size / 2;
        let is_even = fft_size % 2 == 0;
        let doubled_end = if is_even { half } else { half + 1 };

        let mut cepstrum_min = vec![T::zero(); fft_size];
        cepstrum_min[0] = cepstrum[0];
        for (dst, &src) in cepstrum_min[1..doubled_end]
            .iter_mut()
            .zip(&cepstrum[1..doubled_end])
        {
            *dst = two * src;
        }
        if is_even {
            cepstrum_min[half] = cepstrum[half];
        }

        let mut min_log_spectrum = fft_fwd.make_output_vec();
        fft_fwd.process(&mut cepstrum_min, &mut min_log_spectrum)?;

        // exp(re + j*im) expressed in rectangular form.
        for bin in &mut min_log_spectrum {
            let amplitude = bin.re.exp();
            let phase = bin.im;
            *bin = Complex::new(amplitude * phase.cos(), amplitude * phase.sin());
        }

        Ok(min_log_spectrum)
    })
}

/// Computes the minimum-phase impulse response for a magnitude curve.
///
/// Obtains the minimum-phase spectrum from [`compute_min_phase_spectrum`] and
/// runs it through the inverse FFT of length `fft_size`, multiplying every
/// output sample by `1 / fft_size`.
///
/// # Arguments
///
/// * `gain_curve` - magnitudes per bin; must hold `fft_size / 2 + 1` values.
/// * `fft_size` - transform length.
///
/// # Errors
///
/// Propagates every error of [`compute_min_phase_spectrum`], plus
/// [`MinPhaseError::Fft`] if the inverse transform fails.
pub fn compute_min_phase_ir<T: DspFloat>(
    gain_curve: &[T],
    fft_size: usize,
) -> Result<Vec<T>, MinPhaseError> {
    let mut spectrum = compute_min_phase_spectrum(gain_curve, fft_size)?;

    T::with_fft_planner(|planner| {
        let inverse = planner.plan_fft_inverse(fft_size);
        let norm = T::one() / T::from_usize(fft_size);

        let mut impulse = inverse.make_output_vec();
        inverse.process(&mut spectrum, &mut impulse)?;

        // The scale is applied here, after the inverse transform.
        impulse.iter_mut().for_each(|sample| *sample *= norm);
        Ok(impulse)
    })
}

/// Convolves `input` with `ir` using FFT-based overlap-add.
///
/// The input is processed in blocks of `ir.len().next_power_of_two()` samples.
/// Each block is zero-padded to a power-of-two transform length of at least
/// `block + ir.len() - 1`, multiplied by the IR spectrum in the frequency
/// domain, transformed back, and accumulated into the output at the block's
/// offset.
///
/// # Returns
///
/// A vector with the same length as `input`; the part of the convolution that
/// would extend past the end of the input is discarded. If either slice is
/// empty, the result is `input.len()` zeros.
///
/// # Errors
///
/// [`MinPhaseError::Fft`] if any forward or inverse transform fails.
pub fn convolve_ola<T: DspFloat>(input: &[T], ir: &[T]) -> Result<Vec<T>, MinPhaseError> {
    let total = input.len();
    let ir_len = ir.len();
    if input.is_empty() || ir.is_empty() {
        return Ok(vec![T::zero(); total]);
    }

    let block_size = ir_len.next_power_of_two();
    let conv_size = (block_size + ir_len - 1).next_power_of_two();

    T::with_fft_planner(|planner| {
        let forward = planner.plan_fft_forward(conv_size);
        let inverse = planner.plan_fft_inverse(conv_size);
        let norm = T::one() / T::from_usize(conv_size);

        // Spectrum of the zero-padded impulse response, computed once.
        let mut padded_ir = vec![T::zero(); conv_size];
        padded_ir[..ir_len].copy_from_slice(ir);
        let mut ir_bins = forward.make_output_vec();
        forward.process(&mut padded_ir, &mut ir_bins)?;

        let mut output = vec![T::zero(); total];
        // Work buffers reused across blocks.
        let mut time_block = vec![T::zero(); conv_size];
        let mut block_bins = forward.make_output_vec();
        let mut time_result = inverse.make_output_vec();

        for (chunk_index, chunk) in input.chunks(block_size).enumerate() {
            let offset = chunk_index * block_size;

            // Refill the whole buffer on every pass before transforming.
            time_block.fill(T::zero());
            time_block[..chunk.len()].copy_from_slice(chunk);
            forward.process(&mut time_block, &mut block_bins)?;

            // Bin-wise complex product with the IR spectrum.
            for (acc, ir_bin) in block_bins.iter_mut().zip(ir_bins.iter()) {
                let (a_re, a_im) = (acc.re, acc.im);
                acc.re = a_re * ir_bin.re - a_im * ir_bin.im;
                acc.im = a_re * ir_bin.im + a_im * ir_bin.re;
            }

            inverse.process(&mut block_bins, &mut time_result)?;

            // Overlap-add; anything beyond the end of `output` is dropped.
            for (out, &sample) in output[offset..]
                .iter_mut()
                .zip(time_result.iter().take(conv_size))
            {
                *out += sample * norm;
            }
        }

        Ok(output)
    })
}

/// Turns a per-sample transient detection map into a smoothed 0..=1 blend curve.
///
/// The curve is built in two stages:
///
/// 1. **Marking** - for every index `i` with `transient_map[i] > threshold`,
///    the half-open range `[i - lookahead_samples, i + lookahead_samples / 2)`
///    (clipped to the slice) is set to `1.0`. Because the end is exclusive,
///    the transient sample itself is only covered when `lookahead_samples >= 2`.
/// 2. **Smoothing** - when `smooth_samples >= 2`, each output sample is the
///    mean of the marked curve over the centred window
///    `[i - smooth_samples / 2, i + smooth_samples / 2]`, truncated at the
///    slice boundaries and divided by the number of samples actually inside.
///
/// # Arguments
///
/// * `transient_map` - detection strength per sample.
/// * `lookahead_samples` - how far before a transient the blend is raised
///   (half that distance is used after it).
/// * `smooth_samples` - nominal width of the moving average; values below 2
///   disable smoothing and the marked curve is returned directly.
/// * `threshold` - values strictly above this count as transients.
///
/// # Returns
///
/// A vector of the same length as `transient_map` with every value in `0.0..=1.0`.
pub fn compute_blend_curve(
    transient_map: &[f64],
    lookahead_samples: usize,
    smooth_samples: usize,
    threshold: f64,
) -> Vec<f64> {
    let n = transient_map.len();
    let mut raw_blend = vec![0.0; n];

    for (i, &value) in transient_map.iter().enumerate() {
        if value > threshold {
            let start = i.saturating_sub(lookahead_samples);
            let end = (i + lookahead_samples / 2).min(n);
            raw_blend[start..end].fill(1.0);
        }
    }

    if smooth_samples < 2 {
        return raw_blend;
    }

    let half = smooth_samples / 2;
    let mut smoothed = vec![0.0; n];

    // Seed with the samples left of the first window's right edge; the edge
    // sample itself (index `half`) is added by the first loop iteration.
    let mut running_sum: f64 = raw_blend[..half.min(n)].iter().sum();

    for (i, out) in smoothed.iter_mut().enumerate() {
        // Sample entering on the right edge of the window, if it exists.
        if let Some(&entering) = raw_blend.get(i + half) {
            running_sum += entering;
        }
        // Sample that has just left the window on the left. The window for
        // `i` starts at `i - half`, so index `i - half - 1` must be removed
        // as soon as it exists, i.e. for every `i > half`.
        if i > half {
            running_sum -= raw_blend[i - half - 1];
        }
        let actual_window = (i + half + 1).min(n) - i.saturating_sub(half);
        *out = (running_sum / actual_window as f64).clamp(0.0, 1.0);
    }

    smoothed
}

#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use neco_complex::Complex;
    use neco_stft::{cast_vec, DspFloat};

    use super::*;

    /// Sample rate used to map bin indices to frequencies in the fixtures.
    const SAMPLE_RATE: f64 = 48000.0;

    /// Forward real FFT of `input`, used to inspect an impulse response.
    fn forward_spectrum<T: DspFloat>(input: &[T], fft_size: usize) -> Vec<Complex<T>> {
        T::with_fft_planner(|planner| {
            let plan = planner.plan_fft_forward(fft_size);
            let mut scratch = input.to_vec();
            let mut bins = plan.make_output_vec();
            plan.process(&mut scratch, &mut bins)
                .expect("fft buffers from planner");
            bins
        })
    }

    /// Unity gain curve with one bell-shaped boost of `peak_db` at `center_hz`.
    ///
    /// `width_hz` is halved twice before use, matching the fixture formula
    /// `x = (f - center) / ((width / 2) / 2)`.
    fn bell_gain_curve(fft_size: usize, peak_db: f64, center_hz: f64, width_hz: f64) -> Vec<f64> {
        let bin_freq = SAMPLE_RATE / fft_size as f64;
        let peak = 10.0f64.powf(peak_db / 20.0);
        let bw = width_hz / 2.0;
        (0..fft_size / 2 + 1)
            .map(|i| {
                let f = i as f64 * bin_freq;
                let x = (f - center_hz) / (bw / 2.0);
                1.0 + (peak - 1.0) / (1.0 + x * x)
            })
            .collect()
    }

    /// Largest absolute magnitude deviation in dB over the interior bins
    /// (first and last bin excluded), ignoring bins whose target is <= 0.01.
    fn max_magnitude_error_db(spectrum: &[Complex<f64>], gain_curve: &[f64]) -> f64 {
        let last = gain_curve.len() - 1;
        spectrum[1..last]
            .iter()
            .zip(&gain_curve[1..last])
            .filter(|&(_, &expected)| expected > 0.01)
            .map(|(bin, &expected)| {
                let actual = (bin.re * bin.re + bin.im * bin.im).sqrt();
                (20.0 * (actual / expected).log10()).abs()
            })
            .fold(0.0, f64::max)
    }

    #[test]
    fn min_phase_rejects_wrong_gain_curve_len() {
        let err = compute_min_phase_spectrum(&[1.0f64, 2.0], 8).expect_err("invalid len");
        let expected = MinPhaseError::InvalidGainCurveLen {
            expected: 5,
            got: 2,
        };
        assert_eq!(err, expected);
    }

    #[test]
    fn min_phase_ir_has_correct_magnitude() {
        let fft_size = 4096;
        let gain_curve = bell_gain_curve(fft_size, 6.0, 1000.0, 1000.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);

        let max_err_db = max_magnitude_error_db(&spectrum, &gain_curve);
        assert!(max_err_db < 0.01, "magnitude error: {max_err_db:.4}dB");
    }

    #[test]
    fn min_phase_ir_is_causal() {
        let fft_size = 4096;
        let gain_curve = bell_gain_curve(fft_size, 6.0, 1000.0, 1000.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");

        // Compare the energy in the first and last quarter of the response.
        let quarter = fft_size / 4;
        let energy = |samples: &[f64]| -> f64 { samples.iter().map(|x| x * x).sum() };
        let energy_front = energy(&ir[..quarter]);
        let energy_back = energy(&ir[3 * quarter..]);
        assert!(energy_front > energy_back * 100.0);
    }

    #[test]
    fn convolve_ola_identity() {
        let n = 8192;
        let input: Vec<f64> = (0..n)
            .map(|i| (2.0 * PI * 440.0 * i as f64 / 48000.0).sin())
            .collect();

        // Unit impulse padded to 256 taps: convolution must reproduce the input.
        let mut ir = vec![0.0; 256];
        ir[0] = 1.0;

        let output = convolve_ola(&input, &ir).expect("convolve");
        let max_err = output
            .iter()
            .zip(&input)
            .map(|(&got, &want)| (got - want).abs())
            .fold(0.0, f64::max);
        assert!(max_err < 1e-10, "identity error: {max_err:.2e}");
    }

    #[test]
    fn blend_curve_stays_in_range() {
        let transient_map = vec![0.0, 0.2, 0.9, 0.8, 0.1, 0.0];
        let blend = compute_blend_curve(&transient_map, 2, 4, 0.3);

        assert_eq!(blend.len(), transient_map.len());
        assert!(blend.iter().all(|&value| (0.0..=1.0).contains(&value)));
        assert!(blend[0] > 0.0);
        assert!(blend[2] >= blend[5]);
    }

    #[test]
    fn min_phase_ir_f32_has_reasonable_magnitude() {
        let fft_size = 4096;
        let num_bins = fft_size / 2 + 1;

        let gain_curve_f64 = bell_gain_curve(fft_size, 6.0, 1000.0, 1000.0);
        let gain_curve_f32: Vec<f32> = cast_vec(&gain_curve_f64);

        let ir = compute_min_phase_ir(&gain_curve_f32, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);

        let max_err_db = (1..num_bins - 1)
            .filter_map(|i| {
                let bin = &spectrum[i];
                let actual_mag = (bin.re * bin.re + bin.im * bin.im).sqrt();
                let expected_mag = gain_curve_f32[i];
                (expected_mag > 0.01).then(|| (20.0f32 * (actual_mag / expected_mag).log10()).abs())
            })
            .fold(0.0, f32::max);

        assert!(max_err_db < 0.05, "magnitude error: {max_err_db:.4}dB");
    }

    #[test]
    fn min_phase_ir_non_power_of_two_has_reasonable_magnitude() {
        let fft_size = 1535;
        let gain_curve = bell_gain_curve(fft_size, 4.0, 1800.0, 800.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);

        let max_err_db = max_magnitude_error_db(&spectrum, &gain_curve);
        assert!(max_err_db < 0.03, "magnitude error: {max_err_db:.4}dB");
    }
}