Skip to main content

proteus_lib/dsp/effects/convolution_reverb/
convolution.rs

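//! Convolution engine used by reverb and offline processing.
//!
//! Two implementations are provided; exactly one of them must be enabled
//! through Cargo features (enabling both, or neither, is a compile error):
//! - `complex_fft` (feature `complex-fft`, default): full complex FFT using `rustfft`.
//! - `real_fft` (feature `real-fft`): real FFT using `realfft`.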
6
7#[cfg(all(feature = "real-fft", feature = "complex-fft"))]
8compile_error!("Enable only one of: real-fft or complex-fft.");
9
10#[cfg(not(any(feature = "real-fft", feature = "complex-fft")))]
11compile_error!("Enable one of: real-fft or complex-fft.");
12
13#[cfg(feature = "complex-fft")]
14mod complex_fft {
15    use std::collections::VecDeque;
16    use std::sync::Arc;
17
18    use rustfft::{num_complex::Complex, Fft, FftPlanner};
19
20    // Taken from https://github.com/BordenJardine/reverb_vst
21
22    /// Overlap-add convolver based on complex FFTs.
23    #[derive(Clone)]
24    pub struct Convolver {
25        pub fft_size: usize,
26        ir_segments: Vec<Vec<Complex<f32>>>,
27        previous_frame_q: VecDeque<Vec<Complex<f32>>>,
28        pub previous_tail: Vec<f32>,
29        pending_output: Vec<f32>,
30        fft_processor: Arc<dyn Fft<f32>>,
31        ifft_processor: Arc<dyn Fft<f32>>,
32    }
33
34    impl Convolver {
35        /// Create a new convolver for a single-channel impulse response.
36        pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
37            let mut planner = FftPlanner::<f32>::new();
38            let fft_processor = planner.plan_fft_forward(fft_size);
39            let ifft_processor = planner.plan_fft_inverse(fft_size);
40
41            let ir_segments = segment_buffer(ir_signal, fft_size, &fft_processor);
42            let segment_count = ir_segments.len();
43            Self {
44                fft_size,
45                ir_segments,
46                fft_processor,
47                ifft_processor,
48                previous_frame_q: init_previous_frame_q(segment_count, fft_size),
49                previous_tail: init_previous_tail(fft_size / 2),
50                pending_output: Vec::new(),
51            }
52        }
53
54        /// Process a block of input samples and return the convolved output.
55        ///
56        /// The output length matches the input length. Internal tails are
57        /// preserved between calls.
58        pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
59            let io_len = input_buffer.len();
60            let segment_size = self.fft_size / 2;
61            let input_segments = segment_buffer(input_buffer, self.fft_size, &self.fft_processor);
62
63            let mut output: Vec<f32> = Vec::with_capacity(io_len);
64            let norm = self.fft_size as f32;
65
66            if !self.pending_output.is_empty() {
67                let take = io_len.min(self.pending_output.len());
68                output.extend_from_slice(&self.pending_output[..take]);
69                self.pending_output.drain(0..take);
70            }
71
72            for segment in input_segments {
73                self.previous_frame_q.push_front(segment);
74                self.previous_frame_q.pop_back();
75
76                let mut convolved = self.convolve_frame();
77                self.ifft_processor.process(&mut convolved);
78
79                let mut time_domain: Vec<f32> = Vec::with_capacity(self.fft_size);
80                for sample in convolved {
81                    time_domain.push(sample.re / norm);
82                }
83
84                for i in 0..segment_size {
85                    if let Some(sample) = time_domain.get_mut(i) {
86                        *sample += self.previous_tail[i];
87                    }
88                }
89
90                self.previous_tail = time_domain[segment_size..self.fft_size].to_vec();
91                let remaining = io_len.saturating_sub(output.len());
92                if remaining == 0 {
93                    self.pending_output
94                        .extend_from_slice(&time_domain[0..segment_size]);
95                    continue;
96                }
97                if remaining >= segment_size {
98                    output.extend_from_slice(&time_domain[0..segment_size]);
99                } else {
100                    output.extend_from_slice(&time_domain[0..remaining]);
101                    self.pending_output
102                        .extend_from_slice(&time_domain[remaining..segment_size]);
103                }
104            }
105
106            output
107        }
108
109        /// Reset internal FFT history and tail buffers.
110        pub fn clear_state(&mut self) {
111            for frame in &mut self.previous_frame_q {
112                for sample in frame.iter_mut() {
113                    sample.re = 0.0;
114                    sample.im = 0.0;
115                }
116            }
117            self.previous_tail.fill(0.0);
118            self.pending_output.clear();
119        }
120
121        fn convolve_frame(&mut self) -> Vec<Complex<f32>> {
122            let mut convolved = vec![Complex { re: 0.0, im: 0.0 }; self.fft_size];
123
124            for i in 0..self.ir_segments.len() {
125                add_frames(
126                    &mut convolved,
127                    mult_frames(&self.previous_frame_q[i], &self.ir_segments[i]),
128                );
129            }
130            convolved
131        }
132    }
133
134    /// Add one spectrum frame into another in-place.
135    pub fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
136        for (sample1, sample2) in f1.iter_mut().zip(f2) {
137            sample1.re = sample1.re + sample2.re;
138            sample1.im = sample1.im + sample2.im;
139        }
140    }
141
142    /// Multiply two spectra element-wise.
143    pub fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
144        let mut out: Vec<Complex<f32>> = Vec::new();
145        for (sample1, sample2) in f1.iter().zip(f2) {
146            out.push(Complex {
147                re: (sample1.re * sample2.re) - (sample1.im * sample2.im),
148                im: (sample1.im * sample2.re) + (sample1.re * sample2.im),
149            });
150        }
151        out
152    }
153
/// Initialize a zeroed overlap-add tail of the given size.
///
/// Returns a vector of `size` samples, all `0.0`.
pub fn init_previous_tail(size: usize) -> Vec<f32> {
    vec![0.0; size]
}
162
163    /// Segment a time-domain buffer into FFT-sized spectra.
164    pub fn segment_buffer(
165        buffer: &[f32],
166        fft_size: usize,
167        fft_processor: &Arc<dyn Fft<f32>>,
168    ) -> Vec<Vec<Complex<f32>>> {
169        let mut segments = Vec::new();
170        let segment_size = fft_size / 2;
171
172        let mut index = 0;
173        while index < buffer.len() {
174            let mut new_segment: Vec<Complex<f32>> = Vec::new();
175            for i in index..index + segment_size {
176                match buffer.get(i) {
177                    Some(sample) => new_segment.push(Complex {
178                        re: *sample,
179                        im: 0.0,
180                    }),
181                    None => continue,
182                }
183            }
184            while new_segment.len() < fft_size {
185                new_segment.push(Complex { re: 0.0, im: 0.0 });
186            }
187            fft_processor.process(&mut new_segment);
188            segments.push(new_segment);
189            index += segment_size;
190        }
191
192        segments
193    }
194
195    /// Build a queue of empty spectrum frames for overlap-add history.
196    pub fn init_previous_frame_q(
197        segment_count: usize,
198        fft_size: usize,
199    ) -> VecDeque<Vec<Complex<f32>>> {
200        let mut q = VecDeque::new();
201        for _ in 0..segment_count {
202            let mut empty = Vec::new();
203            for _ in 0..fft_size {
204                empty.push(Complex { re: 0.0, im: 0.0 });
205            }
206            q.push_back(empty);
207        }
208        q
209    }
210}
211
212#[cfg(feature = "real-fft")]
213mod real_fft {
214    use std::collections::VecDeque;
215    use std::sync::Arc;
216
217    use realfft::{num_complex::Complex, ComplexToReal, RealFftPlanner, RealToComplex};
218
219    /// Overlap-add convolver based on real FFTs.
220    #[derive(Clone)]
221    pub struct Convolver {
222        pub fft_size: usize,
223        ir_segments: Vec<Vec<Complex<f32>>>,
224        previous_frame_q: VecDeque<Vec<Complex<f32>>>,
225        pub previous_tail: Vec<f32>,
226        pending_output: Vec<f32>,
227        r2c: Arc<dyn RealToComplex<f32>>,
228        c2r: Arc<dyn ComplexToReal<f32>>,
229    }
230
231    impl Convolver {
232        /// Create a new convolver for a single-channel impulse response.
233        pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
234            let mut planner = RealFftPlanner::<f32>::new();
235            let r2c = planner.plan_fft_forward(fft_size);
236            let c2r = planner.plan_fft_inverse(fft_size);
237            let spectrum_len = (fft_size / 2) + 1;
238
239            let ir_segments = segment_buffer(ir_signal, fft_size, &r2c, spectrum_len);
240            let segment_count = ir_segments.len();
241            Self {
242                fft_size,
243                ir_segments,
244                r2c,
245                c2r,
246                previous_frame_q: init_previous_frame_q(segment_count, spectrum_len),
247                previous_tail: init_previous_tail(fft_size / 2),
248                pending_output: Vec::new(),
249            }
250        }
251
252        /// Process a block of input samples and return the convolved output.
253        ///
254        /// The output length matches the input length. Internal tails are
255        /// preserved between calls.
256        pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
257            let io_len = input_buffer.len();
258            let segment_size = self.fft_size / 2;
259            let spectrum_len = self.ir_segments.first().map(|seg| seg.len()).unwrap_or(0);
260            let input_segments =
261                segment_buffer(input_buffer, self.fft_size, &self.r2c, spectrum_len);
262
263            let mut output: Vec<f32> = Vec::with_capacity(io_len);
264            let norm = self.fft_size as f32;
265            let spectrum_len = self.ir_segments.first().map(|seg| seg.len()).unwrap_or(0);
266
267            if !self.pending_output.is_empty() {
268                let take = io_len.min(self.pending_output.len());
269                output.extend_from_slice(&self.pending_output[..take]);
270                self.pending_output.drain(0..take);
271            }
272
273            for segment in input_segments {
274                self.previous_frame_q.push_front(segment);
275                self.previous_frame_q.pop_back();
276
277                let mut convolved = vec![Complex { re: 0.0, im: 0.0 }; spectrum_len];
278                for i in 0..self.ir_segments.len() {
279                    add_frames(
280                        &mut convolved,
281                        mult_frames(&self.previous_frame_q[i], &self.ir_segments[i]),
282                    );
283                }
284
285                let mut time_domain = vec![0.0_f32; self.fft_size];
286                self.c2r
287                    .process(&mut convolved, &mut time_domain)
288                    .expect("real IFFT failed");
289
290                for sample in &mut time_domain {
291                    *sample /= norm;
292                }
293
294                for i in 0..segment_size {
295                    time_domain[i] += self.previous_tail[i];
296                }
297
298                self.previous_tail = time_domain[segment_size..self.fft_size].to_vec();
299                let remaining = io_len.saturating_sub(output.len());
300                if remaining == 0 {
301                    self.pending_output
302                        .extend_from_slice(&time_domain[0..segment_size]);
303                    continue;
304                }
305                if remaining >= segment_size {
306                    output.extend_from_slice(&time_domain[0..segment_size]);
307                } else {
308                    output.extend_from_slice(&time_domain[0..remaining]);
309                    self.pending_output
310                        .extend_from_slice(&time_domain[remaining..segment_size]);
311                }
312            }
313
314            output
315        }
316
317        /// Reset internal FFT history and tail buffers.
318        pub fn clear_state(&mut self) {
319            for frame in &mut self.previous_frame_q {
320                for sample in frame.iter_mut() {
321                    sample.re = 0.0;
322                    sample.im = 0.0;
323                }
324            }
325            self.previous_tail.fill(0.0);
326            self.pending_output.clear();
327        }
328    }
329
330    fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
331        for (sample1, sample2) in f1.iter_mut().zip(f2) {
332            sample1.re = sample1.re + sample2.re;
333            sample1.im = sample1.im + sample2.im;
334        }
335    }
336
337    fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
338        let mut out: Vec<Complex<f32>> = Vec::with_capacity(f1.len());
339        for (sample1, sample2) in f1.iter().zip(f2) {
340            out.push(Complex {
341                re: (sample1.re * sample2.re) - (sample1.im * sample2.im),
342                im: (sample1.im * sample2.re) + (sample1.re * sample2.im),
343            });
344        }
345        out
346    }
347
/// Initialize a zeroed overlap-add tail holding `size` samples.
fn init_previous_tail(size: usize) -> Vec<f32> {
    std::iter::repeat(0.0).take(size).collect()
}
351
352    fn segment_buffer(
353        buffer: &[f32],
354        fft_size: usize,
355        r2c: &Arc<dyn RealToComplex<f32>>,
356        spectrum_len: usize,
357    ) -> Vec<Vec<Complex<f32>>> {
358        let mut segments = Vec::new();
359        let segment_size = fft_size / 2;
360
361        let mut index = 0;
362        while index < buffer.len() {
363            let mut time_domain = vec![0.0_f32; fft_size];
364            for (offset, sample) in buffer.iter().skip(index).take(segment_size).enumerate() {
365                time_domain[offset] = *sample;
366            }
367
368            let mut spectrum = vec![Complex { re: 0.0, im: 0.0 }; spectrum_len];
369            r2c.process(&mut time_domain, &mut spectrum)
370                .expect("real FFT failed");
371            segments.push(spectrum);
372            index += segment_size;
373        }
374
375        segments
376    }
377
378    fn init_previous_frame_q(
379        segment_count: usize,
380        spectrum_len: usize,
381    ) -> VecDeque<Vec<Complex<f32>>> {
382        let mut q = VecDeque::new();
383        for _ in 0..segment_count {
384            q.push_back(vec![Complex { re: 0.0, im: 0.0 }; spectrum_len]);
385        }
386        q
387    }
388}
389
390#[cfg(feature = "complex-fft")]
391pub use complex_fft::Convolver;
392
393#[cfg(feature = "real-fft")]
394pub use real_fft::Convolver;