1use neco_complex::Complex;
2use neco_stft::{DspFloat, FftError};
3
/// Error type returned by `compute_min_phase_spectrum`, `compute_min_phase_ir` and
/// `convolve_ola`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MinPhaseError {
    /// The gain curve length did not match the number of real-FFT bins for the requested
    /// FFT size.
    InvalidGainCurveLen {
        /// Required length, `fft_size / 2 + 1`.
        expected: usize,
        /// Length of the slice actually passed in.
        got: usize,
    },
    /// A forward or inverse FFT call reported a failure.
    Fft(FftError),
}
9
10impl core::fmt::Display for MinPhaseError {
11 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
12 match self {
13 Self::InvalidGainCurveLen { expected, got } => {
14 write!(f, "wrong gain curve length: expected {expected}, got {got}")
15 }
16 Self::Fft(err) => err.fmt(f),
17 }
18 }
19}
20
// Standard error impl so `MinPhaseError` works with `?` into `Box<dyn std::error::Error>`.
// All methods use their defaults, so `source()` returns `None` even for the `Fft` variant.
impl std::error::Error for MinPhaseError {}
22
23impl From<FftError> for MinPhaseError {
24 fn from(value: FftError) -> Self {
25 Self::Fft(value)
26 }
27}
28
29pub fn compute_min_phase_spectrum<T: DspFloat>(
30 gain_curve: &[T],
31 fft_size: usize,
32) -> Result<Vec<Complex<T>>, MinPhaseError> {
33 let num_bins = fft_size / 2 + 1;
34 if gain_curve.len() != num_bins {
35 return Err(MinPhaseError::InvalidGainCurveLen {
36 expected: num_bins,
37 got: gain_curve.len(),
38 });
39 }
40
41 let epsilon = T::from_f64(1e-20);
42 let two = T::from_f64(2.0);
43
44 T::with_fft_planner(|planner| {
45 let fft_fwd = planner.plan_fft_forward(fft_size);
46 let fft_inv = planner.plan_fft_inverse(fft_size);
47 let scale = T::one() / T::from_usize(fft_size);
48
49 let mut log_spectrum: Vec<Complex<T>> = gain_curve
50 .iter()
51 .map(|&gain| Complex::new(gain.max(epsilon).ln(), T::zero()))
52 .collect();
53
54 let mut cepstrum = fft_inv.make_output_vec();
55 fft_inv.process(&mut log_spectrum, &mut cepstrum)?;
56 for value in &mut cepstrum {
57 *value *= scale;
58 }
59
60 let mut cepstrum_min = vec![T::zero(); fft_size];
61 cepstrum_min[0] = cepstrum[0];
62 for i in 1..fft_size / 2 {
63 cepstrum_min[i] = two * cepstrum[i];
64 }
65 cepstrum_min[fft_size / 2] = cepstrum[fft_size / 2];
66
67 let mut min_log_spectrum = fft_fwd.make_output_vec();
68 fft_fwd.process(&mut cepstrum_min, &mut min_log_spectrum)?;
69
70 for bin in &mut min_log_spectrum {
71 let amplitude = bin.re.exp();
72 let phase = bin.im;
73 *bin = Complex::new(amplitude * phase.cos(), amplitude * phase.sin());
74 }
75
76 Ok(min_log_spectrum)
77 })
78}
79
80pub fn compute_min_phase_ir<T: DspFloat>(
81 gain_curve: &[T],
82 fft_size: usize,
83) -> Result<Vec<T>, MinPhaseError> {
84 let mut min_spectrum = compute_min_phase_spectrum(gain_curve, fft_size)?;
85 T::with_fft_planner(|planner| {
86 let fft_inv = planner.plan_fft_inverse(fft_size);
87 let scale = T::one() / T::from_usize(fft_size);
88 let mut ir = fft_inv.make_output_vec();
89 fft_inv.process(&mut min_spectrum, &mut ir)?;
90 for sample in &mut ir {
91 *sample *= scale;
92 }
93 Ok(ir)
94 })
95}
96
97pub fn convolve_ola<T: DspFloat>(input: &[T], ir: &[T]) -> Result<Vec<T>, MinPhaseError> {
98 let n = input.len();
99 let m = ir.len();
100 if n == 0 || m == 0 {
101 return Ok(vec![T::zero(); n]);
102 }
103
104 let block_size = m.next_power_of_two();
105 let conv_size = (block_size + m - 1).next_power_of_two();
106
107 T::with_fft_planner(|planner| {
108 let fft_fwd = planner.plan_fft_forward(conv_size);
109 let fft_inv = planner.plan_fft_inverse(conv_size);
110 let scale = T::one() / T::from_usize(conv_size);
111
112 let mut ir_padded = vec![T::zero(); conv_size];
113 ir_padded[..m].copy_from_slice(ir);
114 let mut ir_spectrum = fft_fwd.make_output_vec();
115 fft_fwd.process(&mut ir_padded, &mut ir_spectrum)?;
116
117 let mut output = vec![T::zero(); n];
118 let mut pos = 0usize;
119 let mut block = vec![T::zero(); conv_size];
120 let mut block_spectrum = fft_fwd.make_output_vec();
121 let mut result = fft_inv.make_output_vec();
122
123 while pos < n {
124 let end = (pos + block_size).min(n);
125 block.fill(T::zero());
126 block[..end - pos].copy_from_slice(&input[pos..end]);
127
128 fft_fwd.process(&mut block, &mut block_spectrum)?;
129 for (lhs, rhs) in block_spectrum.iter_mut().zip(ir_spectrum.iter()) {
130 let re = lhs.re * rhs.re - lhs.im * rhs.im;
131 let im = lhs.re * rhs.im + lhs.im * rhs.re;
132 lhs.re = re;
133 lhs.im = im;
134 }
135
136 fft_inv.process(&mut block_spectrum, &mut result)?;
137 for i in 0..conv_size {
138 if pos + i < n {
139 output[pos + i] += result[i] * scale;
140 }
141 }
142
143 pos += block_size;
144 }
145
146 Ok(output)
147 })
148}
149
/// Builds a blend envelope in `0.0..=1.0` from a per-sample transient detection map.
///
/// Marking step: for every sample `i` whose `transient_map[i]` is strictly greater than
/// `threshold`, the half-open span `[i - lookahead_samples, i + lookahead_samples / 2)` is
/// set to `1.0`. The span is clamped to the map, and all other samples are `0.0`. With
/// `lookahead_samples < 2` that span does not include `i` itself.
///
/// Smoothing step: if `smooth_samples >= 2`, the marked curve is smoothed with a centred
/// moving average over `[i - smooth_samples / 2, i + smooth_samples / 2]` (inclusive). Near
/// the ends of the signal the window is truncated, and the average divides by the number of
/// samples actually inside it. If `smooth_samples < 2`, the marked curve is returned
/// unsmoothed.
///
/// # Returns
/// A curve with the same length as `transient_map`; an empty map yields an empty curve.
pub fn compute_blend_curve(
    transient_map: &[f64],
    lookahead_samples: usize,
    smooth_samples: usize,
    threshold: f64,
) -> Vec<f64> {
    let n = transient_map.len();
    let mut raw_blend = vec![0.0; n];

    for (i, &value) in transient_map.iter().enumerate() {
        if value > threshold {
            let start = i.saturating_sub(lookahead_samples);
            let end = (i + lookahead_samples / 2).min(n);
            raw_blend[start..end].fill(1.0);
        }
    }

    if smooth_samples < 2 {
        return raw_blend;
    }

    // prefix[k] = sum of raw_blend[..k], so each window sum is a single subtraction. The raw
    // values are only 0.0/1.0, so these sums are exact and no rounding drift accumulates.
    let mut prefix = Vec::with_capacity(n + 1);
    prefix.push(0.0);
    let mut acc = 0.0;
    for &value in &raw_blend {
        acc += value;
        prefix.push(acc);
    }

    let half = smooth_samples / 2;
    (0..n)
        .map(|i| {
            // Inclusive window [i - half, i + half], clipped to [0, n); `hi` is exclusive.
            let lo = i.saturating_sub(half);
            let hi = (i + half + 1).min(n);
            let window_sum = prefix[hi] - prefix[lo];
            (window_sum / (hi - lo) as f64).clamp(0.0, 1.0)
        })
        .collect()
}
195
#[cfg(test)]
mod tests {
    use std::f64::consts::PI;

    use neco_complex::Complex;
    use neco_stft::{cast_vec, DspFloat};

    use super::*;

    /// Sample rate used to map FFT bins to Hz in the fixtures below.
    const SAMPLE_RATE: f64 = 48000.0;

    /// Runs a forward real FFT of `input` with a fresh plan and returns the spectrum.
    fn forward_spectrum<T: DspFloat>(input: &[T], fft_size: usize) -> Vec<Complex<T>> {
        T::with_fft_planner(|planner| {
            let fft = planner.plan_fft_forward(fft_size);
            let mut scratch = input.to_vec();
            let mut spectrum = fft.make_output_vec();
            fft.process(&mut scratch, &mut spectrum)
                .expect("fft buffers from planner");
            spectrum
        })
    }

    /// Linear gain per real-FFT bin of a bell-shaped boost:
    /// `1 + (A - 1) / (1 + x^2)` with `A = 10^(peak_db / 20)` and
    /// `x = (f - center_hz) / (bandwidth_hz / 2)`.
    fn bell_gain_curve(
        fft_size: usize,
        center_hz: f64,
        bandwidth_hz: f64,
        peak_db: f64,
    ) -> Vec<f64> {
        let bin_freq = SAMPLE_RATE / fft_size as f64;
        let peak = 10.0f64.powf(peak_db / 20.0);
        (0..fft_size / 2 + 1)
            .map(|bin| {
                let x = (bin as f64 * bin_freq - center_hz) / (bandwidth_hz / 2.0);
                1.0 + (peak - 1.0) / (1.0 + x * x)
            })
            .collect()
    }

    /// Largest |20*log10(actual / expected)| over the interior bins (DC and last bin
    /// excluded) whose expected gain exceeds 0.01.
    fn max_error_db(spectrum: &[Complex<f64>], expected: &[f64]) -> f64 {
        (1..expected.len() - 1)
            .filter(|&bin| expected[bin] > 0.01)
            .map(|bin| {
                let z = &spectrum[bin];
                let actual = (z.re * z.re + z.im * z.im).sqrt();
                (20.0 * (actual / expected[bin]).log10()).abs()
            })
            .fold(0.0, f64::max)
    }

    #[test]
    fn min_phase_rejects_wrong_gain_curve_len() {
        let result = compute_min_phase_spectrum(&[1.0f64, 2.0], 8);
        assert_eq!(
            result.expect_err("invalid len"),
            MinPhaseError::InvalidGainCurveLen {
                expected: 5,
                got: 2,
            }
        );
    }

    #[test]
    fn min_phase_ir_has_correct_magnitude() {
        let fft_size = 4096;
        let gain_curve = bell_gain_curve(fft_size, 1000.0, 500.0, 6.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let max_err_db = max_error_db(&forward_spectrum(&ir, fft_size), &gain_curve);

        assert!(max_err_db < 0.01, "magnitude error: {max_err_db:.4}dB");
    }

    #[test]
    fn min_phase_ir_is_causal() {
        let fft_size = 4096;
        let gain_curve = bell_gain_curve(fft_size, 1000.0, 500.0, 6.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let quarter = fft_size / 4;
        let energy = |samples: &[f64]| -> f64 { samples.iter().map(|s| s * s).sum() };

        assert!(energy(&ir[..quarter]) > energy(&ir[3 * quarter..]) * 100.0);
    }

    #[test]
    fn convolve_ola_identity() {
        let input: Vec<f64> = (0..8192)
            .map(|i| (2.0 * PI * 440.0 * i as f64 / SAMPLE_RATE).sin())
            .collect();
        let mut ir = vec![0.0; 256];
        ir[0] = 1.0;

        let output = convolve_ola(&input, &ir).expect("convolve");
        let max_err = output
            .iter()
            .zip(&input)
            .map(|(got, want)| (got - want).abs())
            .fold(0.0, f64::max);

        assert!(max_err < 1e-10, "identity error: {max_err:.2e}");
    }

    #[test]
    fn blend_curve_stays_in_range() {
        let transient_map = vec![0.0, 0.2, 0.9, 0.8, 0.1, 0.0];
        let blend = compute_blend_curve(&transient_map, 2, 4, 0.3);

        assert_eq!(blend.len(), transient_map.len());
        assert!(blend.iter().all(|&value| (0.0..=1.0).contains(&value)));
        assert!(blend[0] > 0.0);
        assert!(blend[2] >= blend[5]);
    }

    #[test]
    fn min_phase_ir_f32_has_reasonable_magnitude() {
        let fft_size = 4096;
        let gain_curve_f64 = bell_gain_curve(fft_size, 1000.0, 500.0, 6.0);
        let gain_curve_f32: Vec<f32> = cast_vec(&gain_curve_f64);

        let ir = compute_min_phase_ir(&gain_curve_f32, fft_size).expect("min phase ir");
        let spectrum = forward_spectrum(&ir, fft_size);

        // Evaluated in f32 on purpose: this checks single-precision accuracy end to end.
        let max_err_db = (1..gain_curve_f32.len() - 1)
            .filter(|&bin| gain_curve_f32[bin] > 0.01)
            .map(|bin| {
                let z = &spectrum[bin];
                let actual = (z.re * z.re + z.im * z.im).sqrt();
                (20.0f32 * (actual / gain_curve_f32[bin]).log10()).abs()
            })
            .fold(0.0, f32::max);

        assert!(max_err_db < 0.05, "magnitude error: {max_err_db:.4}dB");
    }

    #[test]
    fn min_phase_ir_non_power_of_two_has_reasonable_magnitude() {
        let fft_size = 1535;
        let gain_curve = bell_gain_curve(fft_size, 1800.0, 400.0, 4.0);

        let ir = compute_min_phase_ir(&gain_curve, fft_size).expect("min phase ir");
        let max_err_db = max_error_db(&forward_spectrum(&ir, fft_size), &gain_curve);

        assert!(max_err_db < 0.03, "magnitude error: {max_err_db:.4}dB");
    }
}