Skip to main content

audio_engine_core/processor/
convolver.rs

1//! FFT-based convolution for long FIR filters (Overlap-Save algorithm)
2//!
3//! Zero-allocation real-time implementation with pre-allocated scratch buffers.
4
5use rustfft::{num_complex::Complex, FftPlanner};
6use std::sync::Arc;
7
8/// High-performance FFT convolver (Overlap-Save algorithm).
9///
10/// Zero-allocation implementation: all scratch buffers are pre-allocated at
11/// construction time so `process_into`/`process_inplace` are realtime-safe.
12pub struct FFTConvolver {
13    fft_size: usize,
14    impulse_response_fft: Vec<Vec<Complex<f64>>>, // one frequency-domain response per channel
15    overlap_buffers: Vec<Vec<f64>>,               // overlap buffer per channel
16    channels: usize,
17    ir_len: usize,
18    // Cached FFT plans to avoid recreating on each process call
19    fft_forward: Arc<dyn rustfft::Fft<f64>>,
20    fft_inverse: Arc<dyn rustfft::Fft<f64>>,
21    // Pre-allocated scratch buffers for zero-allocation processing
22    scratch_complex: Vec<Complex<f64>>,
23}
24
25impl Clone for FFTConvolver {
26    fn clone(&self) -> Self {
27        Self {
28            fft_size: self.fft_size,
29            impulse_response_fft: self.impulse_response_fft.clone(),
30            overlap_buffers: self.overlap_buffers.clone(),
31            channels: self.channels,
32            ir_len: self.ir_len,
33            fft_forward: Arc::clone(&self.fft_forward),
34            fft_inverse: Arc::clone(&self.fft_inverse),
35            scratch_complex: self.scratch_complex.clone(),
36        }
37    }
38}
39
40impl FFTConvolver {
41    /// Create a new FFT convolver with the given impulse response
42    ///
43    /// # Arguments
44    /// * `ir_data` - Impulse response samples in interleaved format [L0, R0, L1, R1, ...]
45    /// * `channels` - Number of channels
46    pub fn new(ir_data: &[f64], channels: usize) -> Self {
47        let ir_len_total = ir_data.len();
48        let ir_len_per_ch = ir_len_total / channels;
49
50        // Pick a suitable FFT size (a power of two larger than 2*ir_len).
51        let mut fft_size = 1;
52        while fft_size < (ir_len_per_ch * 2) {
53            fft_size <<= 1;
54        }
55
56        let mut planner = FftPlanner::new();
57        let fft = planner.plan_fft_forward(fft_size);
58
59        // Create cached plans for forward and inverse FFT
60        let fft_forward = planner.plan_fft_forward(fft_size);
61        let fft_inverse = planner.plan_fft_inverse(fft_size);
62
63        let mut ir_ffts = Vec::with_capacity(channels);
64        let mut overlap_bufs = Vec::with_capacity(channels);
65
66        for ch in 0..channels {
67            let mut buffer = vec![Complex::new(0.0, 0.0); fft_size];
68            // Load the IR for this channel and zero-pad the rest.
69            for i in 0..ir_len_per_ch {
70                buffer[i] = Complex::new(ir_data[i * channels + ch], 0.0);
71            }
72            fft.process(&mut buffer);
73            ir_ffts.push(buffer);
74            overlap_bufs.push(vec![0.0; ir_len_per_ch - 1]);
75        }
76
77        // Pre-allocate scratch buffer for FFT workspace
78        let scratch_complex = vec![Complex::new(0.0, 0.0); fft_size];
79
80        FFTConvolver {
81            fft_size,
82            impulse_response_fft: ir_ffts,
83            overlap_buffers: overlap_bufs,
84            channels,
85            ir_len: ir_len_per_ch,
86            fft_forward,
87            fft_inverse,
88            scratch_complex,
89        }
90    }
91
92    /// Get the IR length per channel
93    pub fn ir_length(&self) -> usize {
94        self.ir_len
95    }
96
97    /// Get the FFT size used
98    pub fn fft_size(&self) -> usize {
99        self.fft_size
100    }
101
102    /// Reset internal state (overlap buffers)
103    /// Call this when starting a new track to avoid artifacts
104    pub fn reset(&mut self) {
105        for overlap in &mut self.overlap_buffers {
106            overlap.fill(0.0);
107        }
108    }
109
110    #[inline]
111    #[allow(clippy::too_many_arguments)]
112    fn prepare_channel_chunk(
113        scratch: &mut [Complex<f64>],
114        overlap: &[f64],
115        input: &[f64],
116        channels: usize,
117        channel: usize,
118        processed_frames: usize,
119        chunk_len: usize,
120        ir_len: usize,
121    ) {
122        for i in 0..ir_len - 1 {
123            scratch[i] = Complex::new(overlap[i], 0.0);
124        }
125
126        for i in 0..chunk_len {
127            scratch[i + ir_len - 1] =
128                Complex::new(input[(processed_frames + i) * channels + channel], 0.0);
129        }
130        scratch[ir_len - 1 + chunk_len..].fill(Complex::new(0.0, 0.0));
131    }
132
133    #[inline]
134    fn update_channel_overlap(
135        overlap: &mut [f64],
136        input: &[f64],
137        channels: usize,
138        channel: usize,
139        processed_frames: usize,
140        chunk_len: usize,
141        ir_len: usize,
142    ) {
143        if chunk_len >= ir_len - 1 {
144            for i in 0..ir_len - 1 {
145                overlap[i] =
146                    input[(processed_frames + chunk_len - (ir_len - 1) + i) * channels + channel];
147            }
148        } else {
149            let shift = chunk_len;
150            let keep = ir_len - 1 - shift;
151            overlap.copy_within(shift..shift + keep, 0);
152            for i in 0..shift {
153                overlap[keep + i] = input[(processed_frames + i) * channels + channel];
154            }
155        }
156    }
157
158    #[inline]
159    #[allow(clippy::too_many_arguments)]
160    fn write_channel_output(
161        scratch: &[Complex<f64>],
162        output: &mut [f64],
163        channels: usize,
164        channel: usize,
165        processed_frames: usize,
166        chunk_len: usize,
167        ir_len: usize,
168        inv_n: f64,
169    ) {
170        for i in 0..chunk_len {
171            output[(processed_frames + i) * channels + channel] =
172                scratch[i + ir_len - 1].re * inv_n;
173        }
174    }
175
176    #[inline]
177    fn process_channel_chunk_fft(&mut self, channel: usize) {
178        self.fft_forward.process(&mut self.scratch_complex);
179
180        let ir_fft = &self.impulse_response_fft[channel];
181        multiply_spectrum_in_place(&mut self.scratch_complex, ir_fft);
182
183        self.fft_inverse.process(&mut self.scratch_complex);
184    }
185
186    /// Process audio block with zero allocation
187    ///
188    /// # Arguments
189    /// * `input` - Input samples in interleaved format
190    /// * `output` - Output buffer (must be same size as input)
191    ///
192    /// # Safety
193    /// This method is real-time safe: no heap allocations, no mutex, no syscalls
194    #[inline]
195    pub fn process_into(&mut self, input: &[f64], output: &mut [f64]) {
196        debug_assert_eq!(input.len(), output.len());
197
198        let channels = self.channels;
199        let total_frames = input.len() / channels;
200        let fft_size = self.fft_size;
201        let ir_len = self.ir_len;
202        let step_size = fft_size - ir_len + 1;
203        let inv_n = 1.0 / fft_size as f64;
204
205        // `total_frames` intentionally ignores an incomplete trailing frame.
206        // Keep that remainder deterministic without clearing the whole buffer.
207        output[total_frames * channels..].fill(0.0);
208
209        for ch in 0..channels {
210            let mut processed_frames = 0;
211
212            while processed_frames < total_frames {
213                let chunk_len = std::cmp::min(step_size, total_frames - processed_frames);
214
215                Self::prepare_channel_chunk(
216                    &mut self.scratch_complex,
217                    &self.overlap_buffers[ch],
218                    input,
219                    channels,
220                    ch,
221                    processed_frames,
222                    chunk_len,
223                    ir_len,
224                );
225                self.process_channel_chunk_fft(ch);
226                Self::write_channel_output(
227                    &self.scratch_complex,
228                    output,
229                    channels,
230                    ch,
231                    processed_frames,
232                    chunk_len,
233                    ir_len,
234                    inv_n,
235                );
236
237                Self::update_channel_overlap(
238                    &mut self.overlap_buffers[ch],
239                    input,
240                    channels,
241                    ch,
242                    processed_frames,
243                    chunk_len,
244                    ir_len,
245                );
246
247                processed_frames += chunk_len;
248            }
249        }
250    }
251
252    /// Process audio block, returning a new Vec (convenience wrapper)
253    ///
254    /// Note: This method allocates. For real-time use, prefer process_into().
255    pub fn process(&mut self, input: &[f64]) -> Vec<f64> {
256        let mut output = vec![0.0; input.len()];
257        self.process_into(input, &mut output);
258        output
259    }
260
261    /// Process audio block in-place with zero allocation
262    ///
263    /// Uses internal scratch buffer for temporary storage.
264    /// This is the recommended method for real-time audio processing.
265    ///
266    /// # Arguments
267    /// * `buf` - Input/output samples in interleaved format (modified in place)
268    #[inline]
269    pub fn process_inplace(&mut self, buf: &mut [f64]) {
270        // Use scratch_complex as temporary output buffer
271        // First, we need a separate buffer for output since we can't read and write the same location
272        // We'll use a two-phase approach: save input to scratch, process, write back
273
274        let channels = self.channels;
275        let total_frames = buf.len() / channels;
276        let fft_size = self.fft_size;
277        let ir_len = self.ir_len;
278        let step_size = fft_size - ir_len + 1;
279        let inv_n = 1.0 / fft_size as f64;
280
281        // We need a temporary buffer for output
282        // Re-purpose: use a separate approach - process channel by channel
283        // For each channel, we process and immediately write back
284
285        for ch in 0..channels {
286            let mut processed_frames = 0;
287
288            while processed_frames < total_frames {
289                let chunk_len = std::cmp::min(step_size, total_frames - processed_frames);
290
291                Self::prepare_channel_chunk(
292                    &mut self.scratch_complex,
293                    &self.overlap_buffers[ch],
294                    buf,
295                    channels,
296                    ch,
297                    processed_frames,
298                    chunk_len,
299                    ir_len,
300                );
301                self.process_channel_chunk_fft(ch);
302
303                // 6. Save original input for overlap BEFORE writing output
304                // (This is critical for inplace processing - we need the original input,
305                // not the processed output, for the next chunk's overlap)
306                Self::update_channel_overlap(
307                    &mut self.overlap_buffers[ch],
308                    buf,
309                    channels,
310                    ch,
311                    processed_frames,
312                    chunk_len,
313                    ir_len,
314                );
315
316                // 7. Write processed output to buffer
317                Self::write_channel_output(
318                    &self.scratch_complex,
319                    buf,
320                    channels,
321                    ch,
322                    processed_frames,
323                    chunk_len,
324                    ir_len,
325                    inv_n,
326                );
327
328                processed_frames += chunk_len;
329            }
330        }
331    }
332}
333
334#[inline]
335fn multiply_spectrum_in_place(samples: &mut [Complex<f64>], ir_fft: &[Complex<f64>]) {
336    for (sample, ir) in samples.iter_mut().zip(ir_fft) {
337        let re = sample.re * ir.re - sample.im * ir.im;
338        let im = sample.re * ir.im + sample.im * ir.re;
339        sample.re = re;
340        sample.im = im;
341    }
342}
343
#[cfg(test)]
mod tests {
    //! Unit tests for `FFTConvolver`.
    //!
    //! Besides identity/equivalence checks, the FFT paths are compared against a
    //! direct time-domain convolution, and state carried between calls (the
    //! overlap history) is asserted explicitly.

    use super::*;

    /// Absolute tolerance for comparing FFT results with exact expectations.
    const TOLERANCE: f64 = 1e-10;

    /// Asserts that two sample buffers agree element-wise within `TOLERANCE`.
    ///
    /// A `NaN` on either side fails the comparison, which lets callers pre-fill
    /// output buffers with `NaN` to detect samples that were never written.
    fn assert_all_close(actual: &[f64], expected: &[f64], context: &str) {
        assert_eq!(actual.len(), expected.len(), "{context}: length mismatch");
        for (i, (a, e)) in actual.iter().zip(expected).enumerate() {
            assert!(
                (a - e).abs() < TOLERANCE,
                "{context}: mismatch at {i}: {a} vs {e}"
            );
        }
    }

    /// Reference implementation: direct per-channel convolution of interleaved
    /// data, truncated to the input length.
    fn direct_convolution(ir: &[f64], input: &[f64], channels: usize) -> Vec<f64> {
        let ir_frames = ir.len() / channels;
        let frames = input.len() / channels;
        let mut out = vec![0.0; input.len()];
        for ch in 0..channels {
            for n in 0..frames {
                let mut acc = 0.0;
                for k in 0..ir_frames.min(n + 1) {
                    acc += ir[k * channels + ch] * input[(n - k) * channels + ch];
                }
                out[n * channels + ch] = acc;
            }
        }
        out
    }

    #[test]
    fn test_convolver_identity() {
        // Identity impulse response [1.0, 0.0, 0.0, ...]
        let ir = vec![1.0, 0.0, 0.0, 0.0]; // 4 taps mono
        let mut conv = FFTConvolver::new(&ir, 1);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = vec![0.0; input.len()];

        conv.process_into(&input, &mut output);

        // With identity IR, output should match input
        assert_all_close(&output, &input, "identity");
    }

    #[test]
    fn test_convolver_stereo() {
        // Different response per channel so a channel mix-up cannot go unnoticed:
        // left = identity [1, 0], right = one-frame delay [0, 1].
        let ir = vec![1.0, 0.0, 0.0, 1.0]; // [L0, R0, L1, R1]
        let mut conv = FFTConvolver::new(&ir, 2);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // 4 frames stereo
        let mut output = vec![f64::NAN; input.len()];

        conv.process_into(&input, &mut output);

        // Left passes through (1, 3, 5, 7); right is delayed (0, 2, 4, 6).
        let expected = [1.0, 0.0, 3.0, 2.0, 5.0, 4.0, 7.0, 6.0];
        assert_all_close(&output, &expected, "stereo");
    }

    #[test]
    fn test_repeated_process_into_reaches_steady_state() {
        // Smoke test for many consecutive calls on a long IR. Allocation itself
        // is not measured here; the value check guards against state drift.
        let ir: Vec<f64> = (0..1024).map(|i| (i as f64 / 1024.0).sin()).collect();
        let mut conv = FFTConvolver::new(&ir, 1);

        let input = vec![0.5; 4096];
        let mut output = vec![0.0; 4096];

        for _ in 0..100 {
            conv.process_into(&input, &mut output);
        }

        // Once the history is full of 0.5, every output equals 0.5 * sum(ir).
        let expected = 0.5 * ir.iter().sum::<f64>();
        let last = output[output.len() - 1];
        assert!(
            (last - expected).abs() < 1e-9,
            "steady state: {last} vs {expected}"
        );
    }

    // === Boundary unit tests for process_inplace ===

    #[test]
    fn test_inplace_identity() {
        // Identity IR: process_inplace should preserve input
        let ir = vec![1.0, 0.0, 0.0, 0.0]; // 4 taps mono
        let mut conv = FFTConvolver::new(&ir, 1);

        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut buf = original.clone();

        conv.process_inplace(&mut buf);

        assert_all_close(&buf, &original, "inplace identity");
    }

    #[test]
    fn test_inplace_matches_process_into() {
        // Verify process_inplace produces same output as process_into
        let ir: Vec<f64> = (0..32).map(|i| (i as f64 / 32.0).sin() * 0.1).collect();
        let input: Vec<f64> = (0..256).map(|i| (i as f64 * 0.05).sin()).collect();

        let mut conv1 = FFTConvolver::new(&ir, 1);
        let mut conv2 = FFTConvolver::new(&ir, 1);

        let mut output_into = vec![0.0; input.len()];
        conv1.process_into(&input, &mut output_into);

        let mut buf_inplace = input.clone();
        conv2.process_inplace(&mut buf_inplace);

        assert_all_close(&buf_inplace, &output_into, "inplace vs into");
    }

    /// Runs the same signal through `process`, `process_into` and
    /// `process_inplace` and asserts that all three agree.
    fn assert_processing_paths_equivalent(channels: usize, ir_frames: usize, input_frames: usize) {
        let ir: Vec<f64> = (0..ir_frames * channels)
            .map(|i| ((i + 1) as f64 * 0.17).sin() * 0.05)
            .collect();
        let input: Vec<f64> = (0..input_frames * channels)
            .map(|i| ((i + 3) as f64 * 0.11).cos() * 0.5)
            .collect();

        let mut process_conv = FFTConvolver::new(&ir, channels);
        let mut into_conv = FFTConvolver::new(&ir, channels);
        let mut inplace_conv = FFTConvolver::new(&ir, channels);

        let process_output = process_conv.process(&input);

        // NaN pre-fill: any sample `process_into` fails to write is detected.
        let mut into_output = vec![f64::NAN; input.len()];
        into_conv.process_into(&input, &mut into_output);

        let mut inplace_output = input.clone();
        inplace_conv.process_inplace(&mut inplace_output);

        assert_all_close(&into_output, &process_output, "process/process_into");
        assert_all_close(&inplace_output, &process_output, "process/process_inplace");
    }

    #[test]
    fn test_processing_paths_equivalent_for_boundary_chunk_sizes() {
        assert_processing_paths_equivalent(1, 8, 4);
        assert_processing_paths_equivalent(2, 8, 8);
        assert_processing_paths_equivalent(6, 8, 20);
    }

    #[test]
    fn test_inplace_small_buffer() {
        // Buffer smaller than IR length
        let ir = vec![1.0, 0.5, 0.25, 0.125, 0.0, 0.0, 0.0, 0.0]; // 8 taps mono
        let mut conv = FFTConvolver::new(&ir, 1);

        // Only 4 samples (less than 8-tap IR)
        let mut buf = vec![1.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf);

        // Convolution of a unit impulse with the IR, truncated to 4 samples.
        assert_all_close(&buf, &[1.0, 0.5, 0.25, 0.125], "small buffer");
    }

    #[test]
    fn test_inplace_stereo_identity() {
        // Stereo identity IR
        let ir = vec![1.0, 1.0, 0.0, 0.0]; // 2 taps stereo identity
        let mut conv = FFTConvolver::new(&ir, 2);

        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]; // 4 frames stereo
        let mut buf = original.clone();

        conv.process_inplace(&mut buf);

        assert_all_close(&buf, &original, "stereo inplace identity");
    }

    #[test]
    fn test_inplace_multi_chunk() {
        // Continuity across consecutive calls: an impulse in the LAST sample of
        // the first buffer must produce its filter tail in the second buffer.
        let ir = vec![1.0, 0.5, 0.0, 0.0]; // 4 taps mono
        let mut conv = FFTConvolver::new(&ir, 1);

        let mut buf1 = vec![0.0, 0.0, 0.0, 1.0];
        conv.process_inplace(&mut buf1);

        let mut buf2 = vec![0.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf2);

        assert_all_close(&buf1, &[0.0, 0.0, 0.0, 1.0], "first buffer");
        assert_all_close(&buf2, &[0.5, 0.0, 0.0, 0.0], "carried tail");
    }

    #[test]
    fn test_reset_clears_carried_tail() {
        let ir = vec![1.0, 0.5, 0.0, 0.0]; // 4 taps mono
        let mut conv = FFTConvolver::new(&ir, 1);

        let mut buf1 = vec![0.0, 0.0, 0.0, 1.0];
        conv.process_inplace(&mut buf1);

        // Without the reset the next buffer would start with 0.5.
        conv.reset();

        let mut buf2 = vec![0.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf2);

        assert_all_close(&buf2, &[0.0, 0.0, 0.0, 0.0], "after reset");
    }

    #[test]
    fn test_streaming_blocks_match_direct_convolution() {
        // Feed one signal in uneven blocks (shorter than, equal to and longer
        // than the FFT step) and compare with a direct convolution of the whole.
        let channels = 2;
        let ir_frames = 8;
        let total_frames = 50;

        let ir: Vec<f64> = (0..ir_frames * channels)
            .map(|i| ((i + 1) as f64 * 0.17).sin() * 0.05)
            .collect();
        let input: Vec<f64> = (0..total_frames * channels)
            .map(|i| ((i + 3) as f64 * 0.11).cos() * 0.5)
            .collect();
        let expected = direct_convolution(&ir, &input, channels);

        let mut conv = FFTConvolver::new(&ir, channels);
        let mut actual = Vec::with_capacity(input.len());
        let mut start = 0;
        for block_frames in [3, 1, 16, 30] {
            let end = start + block_frames * channels;
            actual.extend(conv.process(&input[start..end]));
            start = end;
        }

        assert_eq!(start, input.len(), "blocks must cover the whole input");
        assert_all_close(&actual, &expected, "streaming vs direct");
    }

    #[test]
    fn test_process_into_zeroes_incomplete_trailing_frame() {
        let ir = vec![1.0, 1.0]; // 1 tap stereo identity
        let mut conv = FFTConvolver::new(&ir, 2);

        // Two complete stereo frames plus one dangling sample.
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0];
        let mut output = vec![f64::NAN; input.len()];

        conv.process_into(&input, &mut output);

        assert_all_close(&output, &[1.0, 2.0, 3.0, 4.0, 0.0], "trailing frame");
    }

    #[test]
    #[should_panic]
    fn test_new_rejects_zero_channels() {
        let _ = FFTConvolver::new(&[1.0, 0.0], 0);
    }

    #[test]
    #[should_panic]
    fn test_new_rejects_empty_impulse_response() {
        let _ = FFTConvolver::new(&[], 1);
    }
}