proteus_lib/dsp/effects/convolution_reverb/
convolution.rs1#[cfg(all(feature = "real-fft", feature = "complex-fft"))]
8compile_error!("Enable only one of: real-fft or complex-fft.");
9
10#[cfg(not(any(feature = "real-fft", feature = "complex-fft")))]
11compile_error!("Enable one of: real-fft or complex-fft.");
12
#[cfg(feature = "complex-fft")]
mod complex_fft {
    use std::collections::VecDeque;
    use std::sync::Arc;

    use rustfft::{num_complex::Complex, Fft, FftPlanner};

    /// Streaming, uniformly partitioned overlap-add convolver built on complex FFTs.
    ///
    /// The impulse response is cut into blocks of `fft_size / 2` samples whose spectra are
    /// computed once in [`Convolver::new`]. Input is consumed in blocks of the same size.
    /// When a call to [`Convolver::process`] ends in the middle of a block, the partial
    /// block is kept in `pending_input` and completed by the next call, rather than being
    /// zero-padded and committed. This way any sequence of buffer sizes yields one
    /// continuous output stream that is time-aligned with the input.
    #[derive(Clone)]
    pub struct Convolver {
        /// FFT length; input and IR blocks are `fft_size / 2` samples long.
        pub fft_size: usize,
        /// Spectra of the impulse-response blocks, earliest block first.
        ir_segments: Vec<Vec<Complex<f32>>>,
        /// Spectra of the most recent completed input blocks, newest first.
        /// Holds `ir_segments.len().saturating_sub(1)` entries; the block in progress is
        /// not stored here.
        previous_frame_q: VecDeque<Vec<Complex<f32>>>,
        /// Second half of the last completed block's output, overlap-added into the next block.
        pub previous_tail: Vec<f32>,
        /// Input samples of the block in progress (always fewer than `fft_size / 2`
        /// between calls).
        pending_input: Vec<f32>,
        fft_processor: Arc<dyn Fft<f32>>,
        ifft_processor: Arc<dyn Fft<f32>>,
    }

    impl Convolver {
        /// Creates a convolver for `ir_signal` using FFTs of length `fft_size`.
        ///
        /// # Panics
        ///
        /// Panics if `fft_size < 2`: blocks of `fft_size / 2` samples would be empty and
        /// segmentation could never make progress.
        pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
            assert!(fft_size >= 2, "fft_size must be at least 2, got {}", fft_size);

            let mut planner = FftPlanner::<f32>::new();
            let fft_processor = planner.plan_fft_forward(fft_size);
            let ifft_processor = planner.plan_fft_inverse(fft_size);

            let ir_segments = segment_buffer(ir_signal, fft_size, &fft_processor);
            // The block in progress is paired with ir_segments[0]; only the older blocks
            // (one per remaining IR segment) need to be remembered.
            let history_len = ir_segments.len().saturating_sub(1);
            Self {
                fft_size,
                ir_segments,
                fft_processor,
                ifft_processor,
                previous_frame_q: init_previous_frame_q(history_len, fft_size),
                previous_tail: init_previous_tail(fft_size / 2),
                pending_input: Vec::with_capacity(fft_size / 2),
            }
        }

        /// Convolves `input_buffer` with the impulse response.
        ///
        /// Returns exactly `input_buffer.len()` samples, time-aligned with the input.
        /// Buffers of any length may be used. State carries over between calls, so
        /// consecutive calls produce the same stream (up to floating-point rounding) as
        /// one call on the concatenated input.
        pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
            let segment_size = self.fft_size / 2;
            let mut output = Vec::with_capacity(input_buffer.len());
            let mut unread = input_buffer;

            while !unread.is_empty() {
                // Top up the block in progress with as many new samples as fit.
                let start = self.pending_input.len();
                let take = (segment_size - start).min(unread.len());
                let (chunk, rest) = unread.split_at(take);
                unread = rest;
                self.pending_input.extend_from_slice(chunk);

                // Output samples [start, start + take) of this block depend only on input
                // that has already arrived, so they are final even for a partial block.
                let (spectrum, time_domain) = self.render_pending_block();
                output.extend_from_slice(&time_domain[start..start + take]);

                if self.pending_input.len() == segment_size {
                    self.commit_block(spectrum, &time_domain);
                }
            }

            output
        }

        /// Resets the block history, overlap tail and partial input block so the next
        /// call to `process` starts from silence. The impulse response is kept.
        pub fn clear_state(&mut self) {
            for frame in &mut self.previous_frame_q {
                frame.fill(Complex::new(0.0, 0.0));
            }
            self.previous_tail.fill(0.0);
            self.pending_input.clear();
        }

        /// Computes the output of the (possibly partial) block in progress.
        ///
        /// Returns the block's input spectrum and `fft_size` time-domain samples. The
        /// previous block's tail is already added to the first `fft_size / 2` samples.
        fn render_pending_block(&self) -> (Vec<Complex<f32>>, Vec<f32>) {
            let segment_size = self.fft_size / 2;
            let spectrum = block_spectrum(&self.pending_input, self.fft_size, &self.fft_processor);

            // Sum current * ir[0] + (1 block ago) * ir[1] + ..., in the same order as a
            // frame queue whose head is the current block.
            let mut convolved = vec![Complex::new(0.0, 0.0); self.fft_size];
            let frames = std::iter::once(&spectrum).chain(self.previous_frame_q.iter());
            for (frame, ir_segment) in frames.zip(&self.ir_segments) {
                add_frames(&mut convolved, mult_frames(frame, ir_segment));
            }
            self.ifft_processor.process(&mut convolved);

            // rustfft's inverse transform is unnormalised.
            let norm = self.fft_size as f32;
            let mut time_domain: Vec<f32> = convolved.iter().map(|bin| bin.re / norm).collect();
            for (sample, tail) in time_domain[..segment_size]
                .iter_mut()
                .zip(&self.previous_tail)
            {
                *sample += tail;
            }

            (spectrum, time_domain)
        }

        /// Finalises a full block: records its spectrum in the history, keeps the second
        /// half of its output as the next overlap tail, and starts a new empty block.
        fn commit_block(&mut self, spectrum: Vec<Complex<f32>>, time_domain: &[f32]) {
            let segment_size = self.fft_size / 2;
            self.previous_frame_q.push_front(spectrum);
            self.previous_frame_q.pop_back();
            self.previous_tail.clear();
            self.previous_tail
                .extend_from_slice(&time_domain[segment_size..]);
            self.pending_input.clear();
        }
    }

    /// Adds `f2` element-wise into `f1`, stopping at the shorter of the two.
    pub fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
        for (acc, value) in f1.iter_mut().zip(f2) {
            *acc += value;
        }
    }

    /// Element-wise complex product of two spectra, truncated to the shorter one.
    pub fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
        let mut out = Vec::with_capacity(f1.len().min(f2.len()));
        out.extend(f1.iter().zip(f2).map(|(a, b)| Complex {
            re: (a.re * b.re) - (a.im * b.im),
            im: (a.im * b.re) + (a.re * b.im),
        }));
        out
    }

    /// Returns a zeroed overlap tail of `size` samples.
    pub fn init_previous_tail(size: usize) -> Vec<f32> {
        vec![0.0; size]
    }

    /// FFT of `block` (at most `fft_size` samples) zero-padded to `fft_size` points.
    fn block_spectrum(
        block: &[f32],
        fft_size: usize,
        fft_processor: &Arc<dyn Fft<f32>>,
    ) -> Vec<Complex<f32>> {
        let mut spectrum = vec![Complex::new(0.0, 0.0); fft_size];
        for (bin, &sample) in spectrum.iter_mut().zip(block) {
            bin.re = sample;
        }
        fft_processor.process(&mut spectrum);
        spectrum
    }

    /// Splits `buffer` into blocks of `fft_size / 2` samples (the last one may be shorter)
    /// and returns the zero-padded `fft_size`-point spectrum of each block.
    ///
    /// # Panics
    ///
    /// Panics if `fft_size < 2` (zero-length blocks).
    pub fn segment_buffer(
        buffer: &[f32],
        fft_size: usize,
        fft_processor: &Arc<dyn Fft<f32>>,
    ) -> Vec<Vec<Complex<f32>>> {
        buffer
            .chunks(fft_size / 2)
            .map(|chunk| block_spectrum(chunk, fft_size, fft_processor))
            .collect()
    }

    /// Returns `segment_count` all-zero spectra of `fft_size` bins each.
    pub fn init_previous_frame_q(
        segment_count: usize,
        fft_size: usize,
    ) -> VecDeque<Vec<Complex<f32>>> {
        (0..segment_count)
            .map(|_| vec![Complex::new(0.0, 0.0); fft_size])
            .collect()
    }
}

#[cfg(feature = "real-fft")]
mod real_fft {
    use std::collections::VecDeque;
    use std::sync::Arc;

    use realfft::{num_complex::Complex, ComplexToReal, RealFftPlanner, RealToComplex};

    /// Streaming, uniformly partitioned overlap-add convolver built on real-to-complex FFTs.
    ///
    /// The impulse response is cut into blocks of `fft_size / 2` samples whose half
    /// spectra (`fft_size / 2 + 1` bins) are computed once in [`Convolver::new`]. Input is
    /// consumed in blocks of the same size. When a call to [`Convolver::process`] ends in
    /// the middle of a block, the partial block is kept in `pending_input` and completed
    /// by the next call, rather than being zero-padded and committed. This way any
    /// sequence of buffer sizes yields one continuous output stream that is time-aligned
    /// with the input.
    #[derive(Clone)]
    pub struct Convolver {
        /// FFT length; input and IR blocks are `fft_size / 2` samples long.
        pub fft_size: usize,
        /// Half spectra of the impulse-response blocks, earliest block first.
        ir_segments: Vec<Vec<Complex<f32>>>,
        /// Half spectra of the most recent completed input blocks, newest first.
        /// Holds `ir_segments.len().saturating_sub(1)` entries; the block in progress is
        /// not stored here.
        previous_frame_q: VecDeque<Vec<Complex<f32>>>,
        /// Second half of the last completed block's output, overlap-added into the next block.
        pub previous_tail: Vec<f32>,
        /// Input samples of the block in progress (always fewer than `fft_size / 2`
        /// between calls).
        pending_input: Vec<f32>,
        r2c: Arc<dyn RealToComplex<f32>>,
        c2r: Arc<dyn ComplexToReal<f32>>,
    }

    impl Convolver {
        /// Creates a convolver for `ir_signal` using FFTs of length `fft_size`.
        ///
        /// # Panics
        ///
        /// Panics if `fft_size < 2`: blocks of `fft_size / 2` samples would be empty and
        /// segmentation could never make progress.
        pub fn new(ir_signal: &[f32], fft_size: usize) -> Self {
            assert!(fft_size >= 2, "fft_size must be at least 2, got {}", fft_size);

            let mut planner = RealFftPlanner::<f32>::new();
            let r2c = planner.plan_fft_forward(fft_size);
            let c2r = planner.plan_fft_inverse(fft_size);
            let spectrum_len = (fft_size / 2) + 1;

            let ir_segments = segment_buffer(ir_signal, fft_size, &r2c, spectrum_len);
            // The block in progress is paired with ir_segments[0]; only the older blocks
            // (one per remaining IR segment) need to be remembered.
            let history_len = ir_segments.len().saturating_sub(1);
            Self {
                fft_size,
                ir_segments,
                r2c,
                c2r,
                previous_frame_q: init_previous_frame_q(history_len, spectrum_len),
                previous_tail: init_previous_tail(fft_size / 2),
                pending_input: Vec::with_capacity(fft_size / 2),
            }
        }

        /// Convolves `input_buffer` with the impulse response.
        ///
        /// Returns exactly `input_buffer.len()` samples, time-aligned with the input.
        /// Buffers of any length may be used. State carries over between calls, so
        /// consecutive calls produce the same stream (up to floating-point rounding) as
        /// one call on the concatenated input.
        pub fn process(&mut self, input_buffer: &[f32]) -> Vec<f32> {
            let segment_size = self.fft_size / 2;
            let mut output = Vec::with_capacity(input_buffer.len());
            let mut unread = input_buffer;

            while !unread.is_empty() {
                // Top up the block in progress with as many new samples as fit.
                let start = self.pending_input.len();
                let take = (segment_size - start).min(unread.len());
                let (chunk, rest) = unread.split_at(take);
                unread = rest;
                self.pending_input.extend_from_slice(chunk);

                // Output samples [start, start + take) of this block depend only on input
                // that has already arrived, so they are final even for a partial block.
                let (spectrum, time_domain) = self.render_pending_block();
                output.extend_from_slice(&time_domain[start..start + take]);

                if self.pending_input.len() == segment_size {
                    self.commit_block(spectrum, &time_domain);
                }
            }

            output
        }

        /// Resets the block history, overlap tail and partial input block so the next
        /// call to `process` starts from silence. The impulse response is kept.
        pub fn clear_state(&mut self) {
            for frame in &mut self.previous_frame_q {
                frame.fill(Complex::new(0.0, 0.0));
            }
            self.previous_tail.fill(0.0);
            self.pending_input.clear();
        }

        /// Computes the output of the (possibly partial) block in progress.
        ///
        /// Returns the block's half spectrum and `fft_size` time-domain samples. The
        /// previous block's tail is already added to the first `fft_size / 2` samples.
        fn render_pending_block(&self) -> (Vec<Complex<f32>>, Vec<f32>) {
            let segment_size = self.fft_size / 2;
            // Derived from fft_size, not from ir_segments, so an empty IR still gets
            // correctly sized FFT buffers.
            let spectrum_len = segment_size + 1;
            let spectrum =
                block_spectrum(&self.pending_input, self.fft_size, &self.r2c, spectrum_len);

            // Sum current * ir[0] + (1 block ago) * ir[1] + ..., in the same order as a
            // frame queue whose head is the current block.
            let mut convolved = vec![Complex::new(0.0, 0.0); spectrum_len];
            let frames = std::iter::once(&spectrum).chain(self.previous_frame_q.iter());
            for (frame, ir_segment) in frames.zip(&self.ir_segments) {
                add_frames(&mut convolved, mult_frames(frame, ir_segment));
            }

            // The DC bin (and the Nyquist bin for even lengths) of a real signal is purely
            // real. For finite input these imaginary parts are already exactly zero.
            // Forcing them keeps NaN/inf input from making realfft's inverse transform
            // report an error, which `expect` below would turn into a panic.
            convolved[0].im = 0.0;
            if self.fft_size % 2 == 0 {
                convolved[spectrum_len - 1].im = 0.0;
            }

            let mut time_domain = vec![0.0_f32; self.fft_size];
            self.c2r
                .process(&mut convolved, &mut time_domain)
                .expect("real IFFT failed");

            // realfft's inverse transform is unnormalised.
            let norm = self.fft_size as f32;
            for sample in &mut time_domain {
                *sample /= norm;
            }
            for (sample, tail) in time_domain[..segment_size]
                .iter_mut()
                .zip(&self.previous_tail)
            {
                *sample += tail;
            }

            (spectrum, time_domain)
        }

        /// Finalises a full block: records its spectrum in the history, keeps the second
        /// half of its output as the next overlap tail, and starts a new empty block.
        fn commit_block(&mut self, spectrum: Vec<Complex<f32>>, time_domain: &[f32]) {
            let segment_size = self.fft_size / 2;
            self.previous_frame_q.push_front(spectrum);
            self.previous_frame_q.pop_back();
            self.previous_tail.clear();
            self.previous_tail
                .extend_from_slice(&time_domain[segment_size..]);
            self.pending_input.clear();
        }
    }

    /// Adds `f2` element-wise into `f1`, stopping at the shorter of the two.
    fn add_frames(f1: &mut [Complex<f32>], f2: Vec<Complex<f32>>) {
        for (acc, value) in f1.iter_mut().zip(f2) {
            *acc += value;
        }
    }

    /// Element-wise complex product of two spectra, truncated to the shorter one.
    fn mult_frames(f1: &[Complex<f32>], f2: &[Complex<f32>]) -> Vec<Complex<f32>> {
        let mut out = Vec::with_capacity(f1.len().min(f2.len()));
        out.extend(f1.iter().zip(f2).map(|(a, b)| Complex {
            re: (a.re * b.re) - (a.im * b.im),
            im: (a.im * b.re) + (a.re * b.im),
        }));
        out
    }

    /// Returns a zeroed overlap tail of `size` samples.
    fn init_previous_tail(size: usize) -> Vec<f32> {
        vec![0.0; size]
    }

    /// Real FFT of `block` (at most `fft_size` samples) zero-padded to `fft_size` points,
    /// returning `spectrum_len` bins.
    fn block_spectrum(
        block: &[f32],
        fft_size: usize,
        r2c: &Arc<dyn RealToComplex<f32>>,
        spectrum_len: usize,
    ) -> Vec<Complex<f32>> {
        let mut time_domain = vec![0.0_f32; fft_size];
        time_domain[..block.len()].copy_from_slice(block);

        let mut spectrum = vec![Complex::new(0.0, 0.0); spectrum_len];
        r2c.process(&mut time_domain, &mut spectrum)
            .expect("real FFT failed");
        spectrum
    }

    /// Splits `buffer` into blocks of `fft_size / 2` samples (the last one may be shorter)
    /// and returns the zero-padded half spectrum of each block.
    ///
    /// # Panics
    ///
    /// Panics if `fft_size < 2` (zero-length blocks).
    fn segment_buffer(
        buffer: &[f32],
        fft_size: usize,
        r2c: &Arc<dyn RealToComplex<f32>>,
        spectrum_len: usize,
    ) -> Vec<Vec<Complex<f32>>> {
        buffer
            .chunks(fft_size / 2)
            .map(|chunk| block_spectrum(chunk, fft_size, r2c, spectrum_len))
            .collect()
    }

    /// Returns `segment_count` all-zero half spectra of `spectrum_len` bins each.
    fn init_previous_frame_q(
        segment_count: usize,
        spectrum_len: usize,
    ) -> VecDeque<Vec<Complex<f32>>> {
        (0..segment_count)
            .map(|_| vec![Complex::new(0.0, 0.0); spectrum_len])
            .collect()
    }
}

// Expose the convolver of whichever FFT backend is enabled under a single name, so
// callers can use `Convolver` without caring which feature was selected.
#[cfg(feature = "complex-fft")]
pub use self::complex_fft::Convolver;

#[cfg(feature = "real-fft")]
pub use self::real_fft::Convolver;