Skip to main content

neco_minphase/
lib.rs

1use neco_complex::Complex;
2use neco_stft::{DspFloat, FftError};
3
4#[derive(Debug, Clone, PartialEq, Eq)]
5pub enum MinPhaseError {
6    InvalidGainCurveLen { expected: usize, got: usize },
7    Fft(FftError),
8}
9
10impl core::fmt::Display for MinPhaseError {
11    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
12        match self {
13            Self::InvalidGainCurveLen { expected, got } => {
14                write!(f, "wrong gain curve length: expected {expected}, got {got}")
15            }
16            Self::Fft(err) => err.fmt(f),
17        }
18    }
19}
20
21impl std::error::Error for MinPhaseError {}
22
23impl From<FftError> for MinPhaseError {
24    fn from(value: FftError) -> Self {
25        Self::Fft(value)
26    }
27}
28
29pub fn compute_min_phase_spectrum<T: DspFloat>(
30    gain_curve: &[T],
31    fft_size: usize,
32) -> Result<Vec<Complex<T>>, MinPhaseError> {
33    let num_bins = fft_size / 2 + 1;
34    if gain_curve.len() != num_bins {
35        return Err(MinPhaseError::InvalidGainCurveLen {
36            expected: num_bins,
37            got: gain_curve.len(),
38        });
39    }
40
41    let epsilon = T::from_f64(1e-20);
42    let two = T::from_f64(2.0);
43
44    T::with_fft_planner(|planner| {
45        let fft_fwd = planner.plan_fft_forward(fft_size);
46        let fft_inv = planner.plan_fft_inverse(fft_size);
47        let scale = T::one() / T::from_usize(fft_size);
48
49        let mut log_spectrum: Vec<Complex<T>> = gain_curve
50            .iter()
51            .map(|&gain| Complex::new(gain.max(epsilon).ln(), T::zero()))
52            .collect();
53
54        let mut cepstrum = fft_inv.make_output_vec();
55        fft_inv.process(&mut log_spectrum, &mut cepstrum)?;
56        for value in &mut cepstrum {
57            *value *= scale;
58        }
59
60        let mut cepstrum_min = vec![T::zero(); fft_size];
61        cepstrum_min[0] = cepstrum[0];
62        for i in 1..fft_size / 2 {
63            cepstrum_min[i] = two * cepstrum[i];
64        }
65        cepstrum_min[fft_size / 2] = cepstrum[fft_size / 2];
66
67        let mut min_log_spectrum = fft_fwd.make_output_vec();
68        fft_fwd.process(&mut cepstrum_min, &mut min_log_spectrum)?;
69
70        for bin in &mut min_log_spectrum {
71            let amplitude = bin.re.exp();
72            let phase = bin.im;
73            *bin = Complex::new(amplitude * phase.cos(), amplitude * phase.sin());
74        }
75
76        Ok(min_log_spectrum)
77    })
78}
79
80pub fn compute_min_phase_ir<T: DspFloat>(
81    gain_curve: &[T],
82    fft_size: usize,
83) -> Result<Vec<T>, MinPhaseError> {
84    let mut min_spectrum = compute_min_phase_spectrum(gain_curve, fft_size)?;
85    T::with_fft_planner(|planner| {
86        let fft_inv = planner.plan_fft_inverse(fft_size);
87        let scale = T::one() / T::from_usize(fft_size);
88        let mut ir = fft_inv.make_output_vec();
89        fft_inv.process(&mut min_spectrum, &mut ir)?;
90        for sample in &mut ir {
91            *sample *= scale;
92        }
93        Ok(ir)
94    })
95}
96
97pub fn convolve_ola<T: DspFloat>(input: &[T], ir: &[T]) -> Result<Vec<T>, MinPhaseError> {
98    let n = input.len();
99    let m = ir.len();
100    if n == 0 || m == 0 {
101        return Ok(vec![T::zero(); n]);
102    }
103
104    let block_size = m.next_power_of_two();
105    let conv_size = (block_size + m - 1).next_power_of_two();
106
107    T::with_fft_planner(|planner| {
108        let fft_fwd = planner.plan_fft_forward(conv_size);
109        let fft_inv = planner.plan_fft_inverse(conv_size);
110        let scale = T::one() / T::from_usize(conv_size);
111
112        let mut ir_padded = vec![T::zero(); conv_size];
113        ir_padded[..m].copy_from_slice(ir);
114        let mut ir_spectrum = fft_fwd.make_output_vec();
115        fft_fwd.process(&mut ir_padded, &mut ir_spectrum)?;
116
117        let mut output = vec![T::zero(); n];
118        let mut pos = 0usize;
119        let mut block = vec![T::zero(); conv_size];
120        let mut block_spectrum = fft_fwd.make_output_vec();
121        let mut result = fft_inv.make_output_vec();
122
123        while pos < n {
124            let end = (pos + block_size).min(n);
125            block.fill(T::zero());
126            block[..end - pos].copy_from_slice(&input[pos..end]);
127
128            fft_fwd.process(&mut block, &mut block_spectrum)?;
129            for (lhs, rhs) in block_spectrum.iter_mut().zip(ir_spectrum.iter()) {
130                let re = lhs.re * rhs.re - lhs.im * rhs.im;
131                let im = lhs.re * rhs.im + lhs.im * rhs.re;
132                lhs.re = re;
133                lhs.im = im;
134            }
135
136            fft_inv.process(&mut block_spectrum, &mut result)?;
137            for i in 0..conv_size {
138                if pos + i < n {
139                    output[pos + i] += result[i] * scale;
140                }
141            }
142
143            pos += block_size;
144        }
145
146        Ok(output)
147    })
148}
149
/// Builds a per-sample blend curve in `[0, 1]` from a transient-detection map.
///
/// Marking step: every sample `i` whose `transient_map` value is strictly
/// greater than `threshold` sets the window
/// `[i - lookahead_samples, i + lookahead_samples / 2)` to `1.0`, clipped to
/// the signal. All other samples are `0.0`.
///
/// NOTE(review): the window end is exclusive. With `lookahead_samples < 2` the
/// transient sample itself is not marked, and with `0` nothing is. Confirm
/// that callers always pass at least 2.
///
/// Smoothing step: when `smooth_samples >= 2`, the marked curve is passed
/// through a centred moving average over `[i - half, i + half]`, where
/// `half = smooth_samples / 2`. The window is truncated at the edges and
/// divided by the number of samples actually inside it. When
/// `smooth_samples < 2`, the unsmoothed 0/1 curve is returned.
///
/// The result always has `transient_map.len()` entries.
pub fn compute_blend_curve(
    transient_map: &[f64],
    lookahead_samples: usize,
    smooth_samples: usize,
    threshold: f64,
) -> Vec<f64> {
    let n = transient_map.len();
    let mut raw_blend = vec![0.0; n];

    for (i, &value) in transient_map.iter().enumerate() {
        if value > threshold {
            let start = i.saturating_sub(lookahead_samples);
            let end = (i + lookahead_samples / 2).min(n);
            raw_blend[start..end].fill(1.0);
        }
    }

    if smooth_samples < 2 {
        return raw_blend;
    }

    let half = smooth_samples / 2;
    let mut smoothed = vec![0.0; n];
    // Invariant at each output: running_sum == sum of raw_blend over
    // [i - half, i + half] ∩ [0, n). Seed it with the samples before index `half`.
    let mut running_sum = 0.0;
    for value in &raw_blend[..half.min(n)] {
        running_sum += *value;
    }

    for (i, out) in smoothed.iter_mut().enumerate() {
        // Sample entering on the right.
        if let Some(&entering) = raw_blend.get(i + half) {
            running_sum += entering;
        }
        // Sample leaving on the left: index i - half - 1, which exists as soon
        // as i >= half + 1. The old `i > half + 1` check never removed index 0.
        if i > half {
            running_sum -= raw_blend[i - half - 1];
        }
        let actual_window = (i + half + 1).min(n) - i.saturating_sub(half);
        *out = (running_sum / actual_window as f64).clamp(0.0, 1.0);
    }

    smoothed
}
195
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use neco_complex::Complex;
    use neco_stft::{cast_vec, DspFloat};

    use super::*;

    /// Sample rate (Hz) used to map FFT bins to frequencies in the fixtures.
    const SAMPLE_RATE: f64 = 48000.0;

    /// Runs a forward real FFT of length `fft_size` over `input` and returns its bins.
    fn forward_spectrum<T: DspFloat>(input: &[T], fft_size: usize) -> Vec<Complex<T>> {
        T::with_fft_planner(|planner| {
            let plan = planner.plan_fft_forward(fft_size);
            let mut samples = input.to_vec();
            let mut bins = plan.make_output_vec();
            plan.process(&mut samples, &mut bins)
                .expect("fft buffers from planner");
            bins
        })
    }

    /// Per-bin linear gain of a bell-shaped boost of `gain_db` centred at
    /// `center_hz`. It falls off as `1 / (1 + x^2)` with
    /// `x = (f - center_hz) / ((width_hz / 2) / 2)`.
    fn peaking_gain_curve(
        fft_size: usize,
        center_hz: f64,
        gain_db: f64,
        width_hz: f64,
    ) -> Vec<f64> {
        let bin_hz = SAMPLE_RATE / fft_size as f64;
        let peak = 10.0f64.powf(gain_db / 20.0);
        let half_width = width_hz / 2.0;
        (0..fft_size / 2 + 1)
            .map(|bin| {
                let x = (bin as f64 * bin_hz - center_hz) / (half_width / 2.0);
                1.0 + (peak - 1.0) / (1.0 + x * x)
            })
            .collect()
    }

    /// Largest absolute dB deviation between `|spectrum[k]|` and `expected[k]`.
    /// Only interior bins count (DC and the last bin are skipped), and only
    /// those whose expected gain exceeds 0.01.
    fn max_magnitude_error_db(spectrum: &[Complex<f64>], expected: &[f64]) -> f64 {
        (1..expected.len() - 1)
            .filter(|&k| expected[k] > 0.01)
            .map(|k| {
                let bin = &spectrum[k];
                let actual = (bin.re * bin.re + bin.im * bin.im).sqrt();
                (20.0 * (actual / expected[k]).log10()).abs()
            })
            .fold(0.0, f64::max)
    }

    #[test]
    fn min_phase_rejects_wrong_gain_curve_len() {
        let result = compute_min_phase_spectrum(&[1.0f64, 2.0], 8);
        assert_eq!(
            result.expect_err("invalid len"),
            MinPhaseError::InvalidGainCurveLen {
                expected: 5,
                got: 2,
            }
        );
    }

    #[test]
    fn min_phase_ir_has_correct_magnitude() {
        let fft_size = 4096;
        let gain_curve = peaking_gain_curve(fft_size, 1000.0, 6.0, 1000.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);
        let max_err_db = max_magnitude_error_db(&spectrum, &gain_curve);

        assert!(max_err_db < 0.01, "magnitude error: {max_err_db:.4}dB");
    }

    #[test]
    fn min_phase_ir_is_causal() {
        let fft_size = 4096;
        let gain_curve = peaking_gain_curve(fft_size, 1000.0, 6.0, 1000.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let quarter = fft_size / 4;
        let energy = |samples: &[f64]| -> f64 { samples.iter().map(|x| x * x).sum() };
        let energy_front = energy(&ir[..quarter]);
        let energy_back = energy(&ir[3 * quarter..]);

        assert!(energy_front > energy_back * 100.0);
    }

    #[test]
    fn convolve_ola_identity() {
        let input: Vec<f64> = (0..8192)
            .map(|i| (2.0 * PI * 440.0 * i as f64 / SAMPLE_RATE).sin())
            .collect();

        let mut ir = vec![0.0; 256];
        ir[0] = 1.0;
        let output = convolve_ola(&input, &ir).expect("convolve");

        let max_err = input
            .iter()
            .zip(output.iter())
            .map(|(&x, &y)| (y - x).abs())
            .fold(0.0, f64::max);
        assert!(max_err < 1e-10, "identity error: {max_err:.2e}");
    }

    #[test]
    fn blend_curve_stays_in_range() {
        let transient_map = vec![0.0, 0.2, 0.9, 0.8, 0.1, 0.0];
        let blend = compute_blend_curve(&transient_map, 2, 4, 0.3);

        assert_eq!(blend.len(), transient_map.len());
        assert!(blend.iter().all(|&value| (0.0..=1.0).contains(&value)));
        assert!(blend[0] > 0.0);
        assert!(blend[2] >= blend[5]);
    }

    #[test]
    fn min_phase_ir_f32_has_reasonable_magnitude() {
        let fft_size = 4096;
        let gain_curve_f64 = peaking_gain_curve(fft_size, 1000.0, 6.0, 1000.0);
        let gain_curve: Vec<f32> = cast_vec(&gain_curve_f64);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);

        // Evaluated in f32 on purpose, to exercise single-precision accuracy.
        let max_err_db = (1..gain_curve.len() - 1)
            .filter(|&k| gain_curve[k] > 0.01)
            .map(|k| {
                let bin = &spectrum[k];
                let actual = (bin.re * bin.re + bin.im * bin.im).sqrt();
                (20.0f32 * (actual / gain_curve[k]).log10()).abs()
            })
            .fold(0.0, f32::max);

        assert!(max_err_db < 0.05, "magnitude error: {max_err_db:.4}dB");
    }

    #[test]
    fn min_phase_ir_non_power_of_two_has_reasonable_magnitude() {
        let fft_size = 1535;
        let gain_curve = peaking_gain_curve(fft_size, 1800.0, 4.0, 800.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);
        let max_err_db = max_magnitude_error_db(&spectrum, &gain_curve);

        assert!(max_err_db < 0.03, "magnitude error: {max_err_db:.4}dB");
    }
}