fft_convolver/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod fft;
4mod utilities;
5use crate::fft::Fft;
6use crate::utilities::{
7    complex_multiply_accumulate, complex_size, copy_and_pad, next_power_of_2, sum,
8};
9use realfft::num_complex::Complex;
10use realfft::num_traits::Zero;
11use realfft::{FftError, FftNum};
12use rtsan_standalone::nonblocking;
13use thiserror::Error;
14
15#[derive(Error, Debug)]
16pub enum FFTConvolverError {
17    #[error("block size is not allowed to be zero")]
18    BlockSizeZero,
19    #[error("impulse response exceeds configured capacity")]
20    ImpulseResponseExceedsCapacity,
21    #[error("error in fft: {0}")]
22    Fft(#[from] FftError),
23}
24
25/// FFTConvolver
26/// Implementation of a partitioned FFT convolution algorithm with uniform block size.
27///
28/// Some notes on how to use it:
29/// - After initialization with an impulse response, subsequent data portions of
30///   arbitrary length can be convolved. The convolver internally can handle
31///   this by using appropriate buffering.
32/// - The convolver works without "latency" (except for the required
33///   processing time, of course), i.e. the output always is the convolved
34///   input for each processing call.
35///
36/// - The convolver is suitable for real-time processing which means that no
37///   "unpredictable" operations like allocations, locking, API calls, etc. are
38///   performed during processing (all necessary allocations and preparations take
39///   place during initialization).
40#[derive(Clone)]
41pub struct FFTConvolver<F: FftNum> {
42    ir_len: usize,
43    block_size: usize,
44    seg_size: usize,
45    seg_count: usize,
46    active_seg_count: usize,
47    fft_complex_size: usize,
48    segments: Vec<Vec<Complex<F>>>,
49    segments_ir: Vec<Vec<Complex<F>>>,
50    fft_buffer: Vec<F>,
51    fft: Fft<F>,
52    pre_multiplied: Vec<Complex<F>>,
53    conv: Vec<Complex<F>>,
54    overlap: Vec<F>,
55    current: usize,
56    input_buffer: Vec<F>,
57    input_buffer_fill: usize,
58}
59
60impl<F: FftNum> Default for FFTConvolver<F> {
61    fn default() -> Self {
62        Self {
63            ir_len: Default::default(),
64            block_size: Default::default(),
65            seg_size: Default::default(),
66            seg_count: Default::default(),
67            active_seg_count: Default::default(),
68            fft_complex_size: Default::default(),
69            segments: Default::default(),
70            segments_ir: Default::default(),
71            fft_buffer: Default::default(),
72            fft: Default::default(),
73            pre_multiplied: Default::default(),
74            conv: Default::default(),
75            overlap: Default::default(),
76            current: Default::default(),
77            input_buffer: Default::default(),
78            input_buffer_fill: Default::default(),
79        }
80    }
81}
82
83impl<F: FftNum> FFTConvolver<F> {
84    /// Initializes the convolver with an impulse response
85    ///
86    /// This method sets up all internal buffers and prepares the convolver for processing.
87    /// The block size determines the internal partition size and affects efficiency.
88    /// It will be rounded up to the next power of 2.
89    ///
90    /// All memory allocations happen during initialization, making subsequent processing
91    /// operations real-time safe.
92    ///
93    /// # Arguments
94    ///
95    /// * `block_size` - Block size internally used by the convolver (partition size).
96    ///   Will be rounded up to the next power of 2. Must be > 0.
97    /// * `impulse_response` - The impulse response to convolve with. Can be empty.
98    ///
99    /// # Returns
100    ///
101    /// Returns `BlockSizeZero` if block_size is 0.
102    ///
103    /// # Example
104    ///
105    /// ```
106    /// use fft_convolver::FFTConvolver;
107    ///
108    /// let mut convolver = FFTConvolver::<f32>::default();
109    /// let ir = vec![0.5, 0.3, 0.2, 0.1];
110    /// convolver.init(128, &ir).unwrap();
111    /// ```
112    pub fn init(
113        &mut self,
114        block_size: usize,
115        impulse_response: &[F],
116    ) -> Result<(), FFTConvolverError> {
117        if block_size == 0 {
118            return Err(FFTConvolverError::BlockSizeZero);
119        }
120
121        self.ir_len = impulse_response.len();
122
123        if self.ir_len == 0 {
124            return Ok(());
125        }
126
127        self.block_size = next_power_of_2(block_size);
128        self.seg_size = 2 * self.block_size;
129        self.seg_count = (self.ir_len as f64 / self.block_size as f64).ceil() as usize;
130        self.active_seg_count = self.seg_count;
131        self.fft_complex_size = complex_size(self.seg_size);
132
133        // FFT
134        self.fft.init(self.seg_size);
135        self.fft_buffer = vec![F::zero(); self.seg_size];
136
137        // prepare segments
138        self.segments = vec![vec![Complex::zero(); self.fft_complex_size]; self.seg_count];
139
140        // prepare ir
141        self.segments_ir = vec![vec![Complex::zero(); self.fft_complex_size]; self.seg_count];
142        for (i, segment) in self.segments_ir.iter_mut().enumerate() {
143            let remaining = self.ir_len - (i * self.block_size);
144            let size_copy = if remaining >= self.block_size {
145                self.block_size
146            } else {
147                remaining
148            };
149            copy_and_pad(
150                &mut self.fft_buffer,
151                &impulse_response[i * self.block_size..],
152                size_copy,
153            );
154            self.fft.forward(&mut self.fft_buffer, segment)?;
155        }
156
157        // prepare convolution buffers
158        self.pre_multiplied = vec![Complex::zero(); self.fft_complex_size];
159        self.conv = vec![Complex::zero(); self.fft_complex_size];
160        self.overlap.resize(self.block_size, F::zero());
161
162        // prepare input buffer
163        self.input_buffer = vec![F::zero(); self.block_size];
164        self.input_buffer_fill = 0;
165
166        // reset current position
167        self.current = 0;
168
169        Ok(())
170    }
171
172    /// Updates the impulse response without reallocating buffers
173    ///
174    /// This method allows changing the impulse response at runtime while maintaining
175    /// real-time safety by avoiding allocations. The new impulse response must not
176    /// exceed the length of the original impulse response used during initialization.
177    ///
178    /// # Arguments
179    ///
180    /// * `impulse_response` - The new impulse response (must be ≤ original length)
181    ///
182    /// # Returns
183    ///
184    /// Returns `ImpulseResponseExceedsCapacity` if the new impulse response is longer
185    /// than the original one.
186    ///
187    /// # Example
188    ///
189    /// ```
190    /// use fft_convolver::FFTConvolver;
191    ///
192    /// let mut convolver = FFTConvolver::<f32>::default();
193    /// let ir1 = vec![0.5, 0.3, 0.2, 0.1];
194    /// convolver.init(4, &ir1).unwrap();
195    ///
196    /// // Update to a different impulse response of same or shorter length
197    /// let ir2 = vec![0.8, 0.6, 0.4];
198    /// convolver.set_response(&ir2).unwrap();
199    /// ```
200    #[nonblocking]
201    pub fn set_response(&mut self, impulse_response: &[F]) -> Result<(), FFTConvolverError> {
202        if impulse_response.len() > self.ir_len {
203            return Err(FFTConvolverError::ImpulseResponseExceedsCapacity);
204        }
205
206        self.fft_buffer.fill(F::zero());
207        self.conv.fill(Complex::zero());
208        self.pre_multiplied.fill(Complex::zero());
209        self.overlap.fill(F::zero());
210
211        self.active_seg_count =
212            (impulse_response.len() as f64 / self.block_size as f64).ceil() as usize;
213
214        // Prepare IR
215        for (i, segment) in self
216            .segments_ir
217            .iter_mut()
218            .enumerate()
219            .take(self.active_seg_count)
220        {
221            let remaining = impulse_response.len() - (i * self.block_size);
222            let size_copy = if remaining >= self.block_size {
223                self.block_size
224            } else {
225                remaining
226            };
227            copy_and_pad(
228                &mut self.fft_buffer,
229                &impulse_response[i * self.block_size..],
230                size_copy,
231            );
232            self.fft.forward(&mut self.fft_buffer, segment)?;
233        }
234
235        // Clear remaining segments
236        for segment in self.segments_ir.iter_mut().skip(self.active_seg_count) {
237            segment.fill(Complex::zero());
238        }
239
240        self.input_buffer.fill(F::zero());
241        self.input_buffer_fill = 0;
242        self.current = 0;
243
244        Ok(())
245    }
246
247    /// Convolves the input samples with the impulse response and outputs the result
248    ///
249    /// This is a real-time safe operation that performs no allocations. The input and
250    /// output buffers can be of any length. Internal buffering handles arbitrary sizes
251    /// and ensures the output is always properly aligned with the input (zero latency
252    /// except for processing time).
253    ///
254    /// If the convolver has no active impulse response, the output is filled with zeros.
255    ///
256    /// # Arguments
257    ///
258    /// * `input` - The input samples to convolve
259    /// * `output` - Buffer to write the convolution result. Must have the same length as `input`.
260    ///
261    /// # Returns
262    ///
263    /// Returns `Fft` error if an FFT operation fails.
264    ///
265    /// # Example
266    ///
267    /// ```
268    /// use fft_convolver::FFTConvolver;
269    ///
270    /// let mut convolver = FFTConvolver::<f32>::default();
271    /// let ir = vec![0.5, 0.3, 0.2];
272    /// convolver.init(128, &ir).unwrap();
273    ///
274    /// let input = vec![1.0; 256];
275    /// let mut output = vec![0.0; 256];
276    /// convolver.process(&input, &mut output).unwrap();
277    /// ```
278    #[nonblocking]
279    pub fn process(&mut self, input: &[F], output: &mut [F]) -> Result<(), FFTConvolverError> {
280        if self.active_seg_count == 0 {
281            output.fill(F::zero());
282            return Ok(());
283        }
284
285        let mut processed = 0;
286        while processed < output.len() {
287            let input_buffer_was_empty = self.input_buffer_fill == 0;
288            let processing = std::cmp::min(
289                output.len() - processed,
290                self.block_size - self.input_buffer_fill,
291            );
292
293            let input_buffer_pos = self.input_buffer_fill;
294            self.input_buffer[input_buffer_pos..input_buffer_pos + processing]
295                .copy_from_slice(&input[processed..processed + processing]);
296
297            // Forward FFT
298            copy_and_pad(&mut self.fft_buffer, &self.input_buffer, self.block_size);
299            if let Err(err) = self
300                .fft
301                .forward(&mut self.fft_buffer, &mut self.segments[self.current])
302            {
303                output.fill(F::zero());
304                return Err(err.into());
305            }
306
307            // complex multiplication
308            if input_buffer_was_empty {
309                self.pre_multiplied.fill(Complex::zero());
310                for i in 1..self.active_seg_count {
311                    let index_ir = i;
312                    let index_audio = (self.current + i) % self.active_seg_count;
313                    complex_multiply_accumulate(
314                        &mut self.pre_multiplied,
315                        &self.segments_ir[index_ir],
316                        &self.segments[index_audio],
317                    );
318                }
319            }
320            self.conv.copy_from_slice(&self.pre_multiplied);
321            complex_multiply_accumulate(
322                &mut self.conv,
323                &self.segments[self.current],
324                &self.segments_ir[0],
325            );
326
327            // Backward FFT
328            if let Err(err) = self.fft.inverse(&mut self.conv, &mut self.fft_buffer) {
329                output.fill(F::zero());
330                return Err(err.into());
331            }
332
333            // Add overlap
334            sum(
335                &mut output[processed..processed + processing],
336                &self.fft_buffer[input_buffer_pos..input_buffer_pos + processing],
337                &self.overlap[input_buffer_pos..input_buffer_pos + processing],
338            );
339
340            // Input buffer full => Next block
341            self.input_buffer_fill += processing;
342            if self.input_buffer_fill == self.block_size {
343                // Input buffer is empty again now
344                self.input_buffer.fill(F::zero());
345                self.input_buffer_fill = 0;
346                // Save the overlap
347                self.overlap
348                    .copy_from_slice(&self.fft_buffer[self.block_size..self.block_size * 2]);
349
350                // Update the current segment
351                self.current = if self.current > 0 {
352                    self.current - 1
353                } else {
354                    self.active_seg_count - 1
355                };
356            }
357            processed += processing;
358        }
359        Ok(())
360    }
361
362    /// Clears the internal processing state while preserving the impulse response
363    ///
364    /// This real-time safe operation resets all internal buffers that store the
365    /// convolution state, effectively removing any "history" or "tail" from previous
366    /// processing. The impulse response configuration remains intact, so processing
367    /// can continue immediately.
368    ///
369    /// This is useful when handling stream discontinuities such as:
370    /// - Seeking in audio playback
371    /// - Pause/resume operations with large time gaps
372    /// - Switching between different audio sources
373    ///
374    /// After calling `reset()`, the next `process()` call will produce output as if
375    /// the convolver had just been initialized.
376    ///
377    /// # Example
378    ///
379    /// ```
380    /// use fft_convolver::FFTConvolver;
381    ///
382    /// let mut convolver = FFTConvolver::<f32>::default();
383    /// let ir = vec![0.5, 0.3, 0.2];
384    /// convolver.init(128, &ir).unwrap();
385    ///
386    /// let input = vec![1.0; 256];
387    /// let mut output = vec![0.0; 256];
388    /// convolver.process(&input, &mut output).unwrap();
389    ///
390    /// // Clear the state when seeking to a new position
391    /// convolver.reset();
392    ///
393    /// // Continue processing with fresh state
394    /// convolver.process(&input, &mut output).unwrap();
395    /// ```
396    #[nonblocking]
397    pub fn reset(&mut self) {
398        self.input_buffer.fill(F::zero());
399        self.input_buffer_fill = 0;
400
401        self.fft_buffer.fill(F::zero());
402        for segment in &mut self.segments {
403            segment.fill(Complex::zero());
404        }
405
406        self.conv.fill(Complex::zero());
407        self.pre_multiplied.fill(Complex::zero());
408
409        self.overlap.fill(F::zero());
410        self.current = 0;
411    }
412}
413
// Unit tests
#[cfg(test)]
mod tests {
    use crate::{FFTConvolver, FFTConvolverError};

    /// `init` must derive all sizes from the block size and impulse response
    /// and leave every buffer in its expected initial state.
    #[test]
    fn init_test() {
        let mut convolver = FFTConvolver::default();
        let ir = vec![1., 0., 0., 0.];
        convolver.init(10, &ir).unwrap();

        // Derived configuration: block size 10 is rounded up to 16.
        assert_eq!(convolver.ir_len, 4);
        assert_eq!(convolver.block_size, 16);
        assert_eq!(convolver.seg_size, 32);
        assert_eq!(convolver.seg_count, 1);
        assert_eq!(convolver.active_seg_count, 1);
        assert_eq!(convolver.fft_complex_size, 17);

        // Input spectra start out as zeros.
        assert_eq!(convolver.segments.len(), 1);
        assert_eq!(convolver.segments.first().unwrap().len(), 17);
        for bin in convolver.segments.iter().flatten() {
            assert_eq!(bin.re, 0.);
            assert_eq!(bin.im, 0.);
        }

        // The spectrum of a unit impulse is 1 + 0i in every bin.
        assert_eq!(convolver.segments_ir.len(), 1);
        assert_eq!(convolver.segments_ir.first().unwrap().len(), 17);
        for bin in convolver.segments_ir.iter().flatten() {
            assert_eq!(bin.re, 1.);
            assert_eq!(bin.im, 0.);
        }

        // The scratch buffer still holds the zero padded impulse response.
        assert_eq!(convolver.fft_buffer.len(), 32);
        assert_eq!(*convolver.fft_buffer.first().unwrap(), 1.);
        for sample in convolver.fft_buffer.iter().skip(1) {
            assert_eq!(*sample, 0.);
        }

        assert_eq!(convolver.pre_multiplied.len(), 17);
        for bin in convolver.pre_multiplied.iter() {
            assert_eq!(bin.re, 0.);
            assert_eq!(bin.im, 0.);
        }

        assert_eq!(convolver.conv.len(), 17);
        for bin in convolver.conv.iter() {
            assert_eq!(bin.re, 0.);
            assert_eq!(bin.im, 0.);
        }

        assert_eq!(convolver.overlap.len(), 16);
        for sample in convolver.overlap.iter() {
            assert_eq!(*sample, 0.);
        }

        assert_eq!(convolver.input_buffer.len(), 16);
        for sample in convolver.input_buffer.iter() {
            assert_eq!(*sample, 0.);
        }

        assert_eq!(convolver.input_buffer_fill, 0);
    }

    /// Convolving with a unit impulse must return the input unchanged.
    #[test]
    fn process_test() {
        let ir = vec![1., 0., 0., 0.];
        let mut convolver = FFTConvolver::<f32>::default();
        convolver.init(2, &ir).unwrap();

        let input = vec![0., 1., 2., 3.];
        let mut output = vec![0.; 4];
        convolver.process(&input, &mut output).unwrap();

        for (expected, actual) in input.iter().zip(output.iter()) {
            assert_eq!(*expected, *actual);
        }
    }

    /// After `reset()` a used convolver must behave like a freshly
    /// initialized one.
    #[test]
    fn reset_test() {
        // An impulse response with actual filtering characteristics
        let ir = vec![0.5, 0.3, 0.2, 0.1];
        let block_size = 4;
        let test_input = vec![1.0, 1.0, 1.0, 1.0];

        // Convolver under test: build up history, reset, then process again
        let mut used = FFTConvolver::<f32>::default();
        used.init(block_size, &ir).unwrap();

        let history_input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut history_output = vec![0.0; 8];
        used.process(&history_input, &mut history_output).unwrap();

        used.reset();

        let mut output1 = vec![0.0; 4];
        used.process(&test_input, &mut output1).unwrap();

        // Reference convolver: freshly initialized, same input
        let mut fresh = FFTConvolver::<f32>::default();
        fresh.init(block_size, &ir).unwrap();
        let mut output2 = vec![0.0; 4];
        fresh.process(&test_input, &mut output2).unwrap();

        // Both outputs must agree if reset() truly dropped all history
        for (i, (after_reset, reference)) in output1.iter().zip(output2.iter()).enumerate() {
            assert!(
                (after_reset - reference).abs() < 1e-5,
                "Mismatch at index {}: cleared convolver produced {}, fresh convolver produced {}",
                i,
                after_reset,
                reference
            );
        }
    }

    /// `reset()` must not touch the configuration derived by `init`.
    #[test]
    fn reset_preserves_configuration() {
        let ir = vec![0.5, 0.3, 0.2, 0.1];
        let block_size = 4;

        let mut convolver = FFTConvolver::<f32>::default();
        convolver.init(block_size, &ir).unwrap();

        // Snapshot of the configuration right after initialization
        let before = (
            convolver.ir_len,
            convolver.block_size,
            convolver.seg_size,
            convolver.seg_count,
        );

        // Process some data, then reset
        let input = vec![1.0, 2.0, 3.0, 4.0];
        let mut output = vec![0.0; 4];
        convolver.process(&input, &mut output).unwrap();
        convolver.reset();

        // Configuration should be unchanged
        assert_eq!(convolver.ir_len, before.0);
        assert_eq!(convolver.block_size, before.1);
        assert_eq!(convolver.seg_size, before.2);
        assert_eq!(convolver.seg_count, before.3);
    }

    /// `set_response` must yield the same output as initializing directly
    /// with the new impulse response.
    #[test]
    fn set_response_equals_init() {
        let ir1 = vec![0.5, 0.3, 0.2, 0.1];
        let ir2 = vec![0.8, 0.6, 0.4, 0.2];
        let block_size = 4;

        // Initialized with ir1, then switched to ir2
        let mut switched = FFTConvolver::<f32>::default();
        switched.init(block_size, &ir1).unwrap();
        switched.set_response(&ir2).unwrap();

        // Initialized directly with ir2
        let mut direct = FFTConvolver::<f32>::default();
        direct.init(block_size, &ir2).unwrap();

        // Feed both convolvers the same input
        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output1 = vec![0.0; 8];
        let mut output2 = vec![0.0; 8];
        switched.process(&input, &mut output1).unwrap();
        direct.process(&input, &mut output2).unwrap();

        // The outputs should be identical
        for (i, (from_set, from_init)) in output1.iter().zip(output2.iter()).enumerate() {
            assert!(
                (from_set - from_init).abs() < 1e-5,
                "Mismatch at index {}: set_response produced {}, init produced {}",
                i,
                from_set,
                from_init
            );
        }
    }

    /// `set_response` must also work when the new impulse response is
    /// shorter than the original one.
    #[test]
    fn set_response_with_shorter_ir() {
        let ir1 = vec![0.5, 0.3, 0.2, 0.1, 0.05, 0.02];
        let ir2 = vec![0.8, 0.6, 0.4];
        let block_size = 4;

        // Initialized with the longer response, then switched to the shorter one
        let mut switched = FFTConvolver::<f32>::default();
        switched.init(block_size, &ir1).unwrap();
        switched.set_response(&ir2).unwrap();

        // Initialized directly with the shorter response
        let mut direct = FFTConvolver::<f32>::default();
        direct.init(block_size, &ir2).unwrap();

        // Feed both convolvers the same input
        let input = vec![1.0, 1.0, 1.0, 1.0];
        let mut output1 = vec![0.0; 4];
        let mut output2 = vec![0.0; 4];
        switched.process(&input, &mut output1).unwrap();
        direct.process(&input, &mut output2).unwrap();

        // The outputs should be identical
        for (i, (from_set, from_init)) in output1.iter().zip(output2.iter()).enumerate() {
            assert!(
                (from_set - from_init).abs() < 1e-5,
                "Mismatch at index {}: set_response produced {}, init produced {}",
                i,
                from_set,
                from_init
            );
        }
    }

    /// `set_response` must reject an impulse response that is longer than
    /// the one used for initialization.
    #[test]
    fn set_response_too_long_returns_error() {
        let ir1 = vec![0.5, 0.3, 0.2, 0.1];
        let ir2 = vec![0.8, 0.6, 0.4, 0.2, 0.1, 0.05];
        let block_size = 4;

        let mut convolver = FFTConvolver::<f32>::default();
        convolver.init(block_size, &ir1).unwrap();

        // Attempting to set a longer IR should fail
        let result = convolver.set_response(&ir2);
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            FFTConvolverError::ImpulseResponseExceedsCapacity
        ));
    }

    /// The algorithm must have zero latency: an impulse at `input[0]` has to
    /// show up in the output starting at `output[0]`.
    #[test]
    fn test_zero_latency() {
        let ir = vec![0.5, 0.3, 0.2, 0.1];
        let mut convolver = FFTConvolver::<f32>::default();
        convolver.init(4, &ir).unwrap();

        // A single impulse at the very first sample, followed by silence
        let mut input = vec![0.0; 16];
        input[0] = 1.0;

        let mut output = vec![0.0; 16];
        convolver.process(&input, &mut output).unwrap();

        // With any latency, output[0] would still be 0.0
        assert!(
            output[0].abs() > 0.0,
            "Output[0] should be non-zero, indicating zero latency. Got: {}",
            output[0]
        );

        // The first samples must reproduce the impulse response
        assert!(
            (output[0] - 0.5).abs() < 1e-5,
            "output[0] should be 0.5, got {}",
            output[0]
        );
        assert!(
            (output[1] - 0.3).abs() < 1e-5,
            "output[1] should be 0.3, got {}",
            output[1]
        );
        assert!(
            (output[2] - 0.2).abs() < 1e-5,
            "output[2] should be 0.2, got {}",
            output[2]
        );
        assert!(
            (output[3] - 0.1).abs() < 1e-5,
            "output[3] should be 0.1, got {}",
            output[3]
        );
    }
}