Skip to main content

proteus_lib/dsp/
convolution.rs

1//! Convolution engine used by reverb and offline processing.
2//!
3//! Two implementations are provided:
4//! - `complex_fft` (default): full complex FFT using `rustfft`.
5//! - `real_fft` (feature `real-fft`): real FFT using `realfft`.
6
7#[cfg(not(feature = "real-fft"))]
8mod complex_fft {
9    use std::collections::VecDeque;
10    use std::sync::Arc;
11
12    use rustfft::{num_complex::Complex, Fft, FftPlanner};
13
14    // Taken from https://github.com/BordenJardine/reverb_vst
15
16    /// Overlap-add convolver based on complex FFTs.
17    #[derive(Clone)]
18    pub struct Convolver {
19        pub fft_size: usize,
20        ir_segments: Vec<Vec<Complex<f32>>>,
21        previous_frame_q: VecDeque<Vec<Complex<f32>>>,
22        pub previous_tail: Vec<f32>,
23        pending_output: Vec<f32>,
24        fft_processor: Arc<dyn Fft<f32>>,
25        ifft_processor: Arc<dyn Fft<f32>>,
26    }
27
28    impl Convolver {
29        /// Create a new convolver for a single-channel impulse response.
30        pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
31            let mut planner = FftPlanner::<f32>::new();
32            let fft_processor = planner.plan_fft_forward(fft_size);
33            let ifft_processor = planner.plan_fft_inverse(fft_size);
34
35            let ir_segments = segment_buffer(ir_signal, fft_size, &fft_processor);
36            let segment_count = ir_segments.len();
37            Self {
38                fft_size,
39                ir_segments,
40                fft_processor,
41                ifft_processor,
42                previous_frame_q: init_previous_frame_q(segment_count, fft_size),
43                previous_tail: init_previous_tail(fft_size / 2),
44                pending_output: Vec::new(),
45            }
46        }
47
48        /// Process a block of input samples and return the convolved output.
49        ///
50        /// The output length matches the input length. Internal tails are
51        /// preserved between calls.
52        pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
53            let io_len = input_buffer.len();
54            let segment_size = self.fft_size / 2;
55            let input_segments = segment_buffer(input_buffer, self.fft_size, &self.fft_processor);
56
57            let mut output: Vec<f32> = Vec::with_capacity(io_len);
58            let norm = self.fft_size as f32;
59
60            if !self.pending_output.is_empty() {
61                let take = io_len.min(self.pending_output.len());
62                output.extend_from_slice(&self.pending_output[..take]);
63                self.pending_output.drain(0..take);
64            }
65
66            for segment in input_segments {
67                self.previous_frame_q.push_front(segment);
68                self.previous_frame_q.pop_back();
69
70                let mut convolved = self.convolve_frame();
71                self.ifft_processor.process(&mut convolved);
72
73                let mut time_domain: Vec<f32> = Vec::with_capacity(self.fft_size);
74                for sample in convolved {
75                    time_domain.push(sample.re / norm);
76                }
77
78                for i in 0..segment_size {
79                    if let Some(sample) = time_domain.get_mut(i) {
80                        *sample += self.previous_tail[i];
81                    }
82                }
83
84                self.previous_tail = time_domain[segment_size..self.fft_size].to_vec();
85                let remaining = io_len.saturating_sub(output.len());
86                if remaining == 0 {
87                    self.pending_output
88                        .extend_from_slice(&time_domain[0..segment_size]);
89                    continue;
90                }
91                if remaining >= segment_size {
92                    output.extend_from_slice(&time_domain[0..segment_size]);
93                } else {
94                    output.extend_from_slice(&time_domain[0..remaining]);
95                    self.pending_output
96                        .extend_from_slice(&time_domain[remaining..segment_size]);
97                }
98            }
99
100            output
101        }
102
103        /// Reset internal FFT history and tail buffers.
104        pub fn clear_state(&mut self) {
105            for frame in &mut self.previous_frame_q {
106                for sample in frame.iter_mut() {
107                    sample.re = 0.0;
108                    sample.im = 0.0;
109                }
110            }
111            self.previous_tail.fill(0.0);
112            self.pending_output.clear();
113        }
114
115        fn convolve_frame(&mut self) -> Vec<Complex<f32>> {
116            let mut convolved = vec![Complex { re: 0.0, im: 0.0 }; self.fft_size];
117
118            for i in 0..self.ir_segments.len() {
119                add_frames(
120                    &mut convolved,
121                    mult_frames(&self.previous_frame_q[i], &self.ir_segments[i]),
122                );
123            }
124            convolved
125        }
126    }
127
128    /// Add one spectrum frame into another in-place.
129    pub fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
130        for (sample1, sample2) in f1.iter_mut().zip(f2) {
131            sample1.re = sample1.re + sample2.re;
132            sample1.im = sample1.im + sample2.im;
133        }
134    }
135
136    /// Multiply two spectra element-wise.
137    pub fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
138        let mut out: Vec<Complex<f32>> = Vec::new();
139        for (sample1, sample2) in f1.iter().zip(f2) {
140            out.push(Complex {
141                re: (sample1.re * sample2.re) - (sample1.im * sample2.im),
142                im: (sample1.im * sample2.re) + (sample1.re * sample2.im),
143            });
144        }
145        out
146    }
147
148    /// Initialize a zeroed overlap-add tail of the given size.
149    pub fn init_previous_tail(size: usize) -> Vec<f32> {
150        let mut tail = Vec::new();
151        for _ in 0..size {
152            tail.push(0.0);
153        }
154        tail
155    }
156
157    /// Segment a time-domain buffer into FFT-sized spectra.
158    pub fn segment_buffer(
159        buffer: &[f32],
160        fft_size: usize,
161        fft_processor: &Arc<dyn Fft<f32>>,
162    ) -> Vec<Vec<Complex<f32>>> {
163        let mut segments = Vec::new();
164        let segment_size = fft_size / 2;
165
166        let mut index = 0;
167        while index < buffer.len() {
168            let mut new_segment: Vec<Complex<f32>> = Vec::new();
169            for i in index..index + segment_size {
170                match buffer.get(i) {
171                    Some(sample) => new_segment.push(Complex { re: *sample, im: 0.0 }),
172                    None => continue,
173                }
174            }
175            while new_segment.len() < fft_size {
176                new_segment.push(Complex { re: 0.0, im: 0.0 });
177            }
178            fft_processor.process(&mut new_segment);
179            segments.push(new_segment);
180            index += segment_size;
181        }
182
183        segments
184    }
185
186    /// Build a queue of empty spectrum frames for overlap-add history.
187    pub fn init_previous_frame_q(
188        segment_count: usize,
189        fft_size: usize,
190    ) -> VecDeque<Vec<Complex<f32>>> {
191        let mut q = VecDeque::new();
192        for _ in 0..segment_count {
193            let mut empty = Vec::new();
194            for _ in 0..fft_size {
195                empty.push(Complex { re: 0.0, im: 0.0 });
196            }
197            q.push_back(empty);
198        }
199        q
200    }
201}
202
#[cfg(feature = "real-fft")]
mod real_fft {
    use std::collections::VecDeque;
    use std::sync::Arc;

    use realfft::{ComplexToReal, RealFftPlanner, RealToComplex};
    use rustfft::num_complex::Complex;

    /// Zero value used to initialise and pad spectra.
    const ZERO: Complex<f32> = Complex { re: 0.0, im: 0.0 };

    /// Uniformly partitioned overlap-add convolver based on real FFTs.
    ///
    /// The impulse response and the input are both split into segments of
    /// `fft_size / 2` samples. Spectra hold `fft_size / 2 + 1` bins (the
    /// non-redundant half of a real signal's spectrum).
    #[derive(Clone)]
    pub struct Convolver {
        /// FFT length; the partition (segment) size is `fft_size / 2`.
        pub fft_size: usize,
        /// Spectra of the impulse-response segments, earliest first.
        ir_segments: Vec<Vec<Complex<f32>>>,
        /// Spectra of completed input segments, most recent first.
        previous_frame_q: VecDeque<Vec<Complex<f32>>>,
        /// Second half of the last completed segment's output; it is
        /// overlap-added onto the first half of the next segment.
        pub previous_tail: Vec<f32>,
        /// Samples of the current, incomplete input segment. Output for these
        /// samples has already been returned; the segment is re-rendered and
        /// committed to the history once later input completes it.
        partial_input: Vec<f32>,
        r2c: Arc<dyn RealToComplex<f32>>,
        c2r: Arc<dyn ComplexToReal<f32>>,
    }

    impl Convolver {
        /// Create a new convolver for a single-channel impulse response.
        ///
        /// # Panics
        ///
        /// Panics if `fft_size < 2`: the segment size `fft_size / 2` would be
        /// zero and processing could never consume any input.
        pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
            assert!(
                fft_size >= 2,
                "Convolver fft_size must be at least 2, got {}",
                fft_size
            );
            let mut planner = RealFftPlanner::<f32>::new();
            let r2c = planner.plan_fft_forward(fft_size);
            let c2r = planner.plan_fft_inverse(fft_size);
            let spectrum_len = (fft_size / 2) + 1;

            let ir_segments = segment_buffer(ir_signal, fft_size, &r2c, spectrum_len);
            let segment_count = ir_segments.len();
            Self {
                fft_size,
                ir_segments,
                r2c,
                c2r,
                previous_frame_q: init_previous_frame_q(segment_count, spectrum_len),
                previous_tail: init_previous_tail(fft_size / 2),
                partial_input: Vec::with_capacity(fft_size / 2),
            }
        }

        /// Process a block of input samples and return the convolved output.
        ///
        /// The output length always matches the input length, and no latency
        /// is added, for any block size. If the block does not end on a
        /// segment boundary (`fft_size / 2` samples), the trailing partial
        /// segment is convolved with zero padding so its outputs can be
        /// returned immediately. It is only committed to the history once a
        /// later call completes it, at the cost of re-rendering that segment.
        /// Internal tails are preserved between calls.
        pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
            let segment_size = self.fft_size / 2;
            let mut output: Vec<f32> = Vec::with_capacity(input_buffer.len());
            let mut remaining = input_buffer;

            while !remaining.is_empty() {
                // Outputs of the current segment already returned by earlier calls.
                let emitted = self.partial_input.len();
                let take = (segment_size - emitted).min(remaining.len());
                let (chunk, rest) = remaining.split_at(take);
                self.partial_input.extend_from_slice(chunk);
                remaining = rest;

                let frame = fft_segment(
                    &self.partial_input,
                    self.fft_size,
                    &self.r2c,
                    self.spectrum_len(),
                );
                let time_domain = self.render_segment(&frame);
                output.extend_from_slice(&time_domain[emitted..emitted + take]);

                if self.partial_input.len() == segment_size {
                    // Segment complete: its spectrum joins the history and its
                    // second half becomes the overlap for the next segment.
                    self.previous_frame_q.push_front(frame);
                    self.previous_frame_q.pop_back();
                    self.previous_tail = time_domain[segment_size..].to_vec();
                    self.partial_input.clear();
                }
            }

            output
        }

        /// Reset internal FFT history, the overlap tail and any partial segment.
        pub fn clear_state(&mut self) {
            for frame in &mut self.previous_frame_q {
                frame.fill(ZERO);
            }
            self.previous_tail.fill(0.0);
            self.partial_input.clear();
        }

        /// Number of complex bins in a real-input spectrum of length `fft_size`.
        fn spectrum_len(&self) -> usize {
            (self.fft_size / 2) + 1
        }

        /// Render one segment and return `fft_size` time-domain samples.
        ///
        /// Steps:
        /// 1. Multiply-accumulate `current` (the spectrum of the segment being
        ///    rendered) and the committed input history against the IR
        ///    segments.
        /// 2. Inverse-transform and normalise the result.
        /// 3. Overlap-add the previous tail onto the first half.
        fn render_segment(&self, current: &[Complex<f32>]) -> Vec<f32> {
            let mut spectrum = vec![ZERO; self.spectrum_len()];
            for (age, ir_segment) in self.ir_segments.iter().enumerate() {
                // IR segment `age` pairs with the input segment `age` steps back;
                // age 0 is the segment being rendered.
                let input_frame = match age {
                    0 => current,
                    _ => self.previous_frame_q[age - 1].as_slice(),
                };
                add_frames(&mut spectrum, mult_frames(input_frame, ir_segment));
            }

            // realfft's complex-to-real transform requires the imaginary part
            // of the DC bin (and, for even lengths, the Nyquist bin) to be zero
            // and returns an error otherwise. They are zero for finite input;
            // clearing them keeps non-finite input from becoming a panic below.
            spectrum[0].im = 0.0;
            if self.fft_size % 2 == 0 {
                if let Some(nyquist) = spectrum.last_mut() {
                    nyquist.im = 0.0;
                }
            }

            let mut time_domain = vec![0.0_f32; self.fft_size];
            self.c2r
                .process(&mut spectrum, &mut time_domain)
                .expect("real IFFT failed");

            // The forward+inverse round trip scales by fft_size; undo that.
            let norm = self.fft_size as f32;
            for sample in &mut time_domain {
                *sample /= norm;
            }

            let segment_size = self.fft_size / 2;
            for (sample, tail) in time_domain[..segment_size]
                .iter_mut()
                .zip(&self.previous_tail)
            {
                *sample += *tail;
            }
            time_domain
        }
    }

    /// Add one spectrum frame into another in-place (overlapping prefix only).
    fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
        for (acc, sample) in f1.iter_mut().zip(f2) {
            acc.re += sample.re;
            acc.im += sample.im;
        }
    }

    /// Multiply two spectra element-wise.
    ///
    /// The result has the length of the shorter input.
    fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
        f1.iter()
            .zip(f2)
            .map(|(a, b)| Complex {
                re: (a.re * b.re) - (a.im * b.im),
                im: (a.im * b.re) + (a.re * b.im),
            })
            .collect()
    }

    /// Initialize a zeroed overlap-add tail of the given size.
    fn init_previous_tail(size: usize) -> Vec<f32> {
        vec![0.0; size]
    }

    /// Split a time-domain buffer into `fft_size / 2`-sample chunks.
    ///
    /// Returns the `spectrum_len`-bin real spectrum of each chunk after
    /// zero-padding it to `fft_size`.
    ///
    /// # Panics
    ///
    /// Panics if `fft_size < 2` (the chunk size would be zero).
    fn segment_buffer(
        buffer: &[f32],
        fft_size: usize,
        r2c: &Arc<dyn RealToComplex<f32>>,
        spectrum_len: usize,
    ) -> Vec<Vec<Complex<f32>>> {
        let segment_size = fft_size / 2;
        assert!(segment_size > 0, "segment_buffer requires fft_size >= 2");
        buffer
            .chunks(segment_size)
            .map(|chunk| fft_segment(chunk, fft_size, r2c, spectrum_len))
            .collect()
    }

    /// Zero-pad `samples` (at most `fft_size` long) to `fft_size`.
    ///
    /// Returns the real forward spectrum of the padded buffer.
    fn fft_segment(
        samples: &[f32],
        fft_size: usize,
        r2c: &Arc<dyn RealToComplex<f32>>,
        spectrum_len: usize,
    ) -> Vec<Complex<f32>> {
        let mut time_domain = vec![0.0_f32; fft_size];
        time_domain[..samples.len()].copy_from_slice(samples);
        let mut spectrum = vec![ZERO; spectrum_len];
        r2c.process(&mut time_domain, &mut spectrum)
            .expect("real FFT failed");
        spectrum
    }

    /// Build a queue of zeroed spectrum frames for the input history.
    fn init_previous_frame_q(
        segment_count: usize,
        spectrum_len: usize,
    ) -> VecDeque<Vec<Complex<f32>>> {
        (0..segment_count).map(|_| vec![ZERO; spectrum_len]).collect()
    }
}
387
388#[cfg(not(feature = "real-fft"))]
389pub use complex_fft::Convolver;
390
391#[cfg(feature = "real-fft")]
392pub use real_fft::Convolver;