1use rustfft::{num_complex::Complex, FftPlanner};
6use std::sync::Arc;
7
/// Multi-channel convolver that filters interleaved audio with a fixed
/// impulse response using FFT block processing.
///
/// Each channel has its own impulse-response spectrum and its own input
/// history, so filter state carries over between successive processing calls.
pub struct FFTConvolver {
    /// Transform length shared by the forward and inverse FFT.
    fft_size: usize,
    /// Forward transform of each channel's zero-padded impulse response.
    impulse_response_fft: Vec<Vec<Complex<f64>>>,
    /// Per channel, the most recent `ir_len - 1` input samples.
    overlap_buffers: Vec<Vec<f64>>,
    /// Number of interleaved channels.
    channels: usize,
    /// Impulse-response length in frames (samples per channel).
    ir_len: usize,
    /// Planned forward FFT of length `fft_size`.
    fft_forward: Arc<dyn rustfft::Fft<f64>>,
    /// Planned inverse FFT of length `fft_size`.
    fft_inverse: Arc<dyn rustfft::Fft<f64>>,
    /// Working buffer of `fft_size` bins, reused for every processed chunk.
    scratch_complex: Vec<Complex<f64>>,
}
24
25impl Clone for FFTConvolver {
26 fn clone(&self) -> Self {
27 Self {
28 fft_size: self.fft_size,
29 impulse_response_fft: self.impulse_response_fft.clone(),
30 overlap_buffers: self.overlap_buffers.clone(),
31 channels: self.channels,
32 ir_len: self.ir_len,
33 fft_forward: Arc::clone(&self.fft_forward),
34 fft_inverse: Arc::clone(&self.fft_inverse),
35 scratch_complex: self.scratch_complex.clone(),
36 }
37 }
38}
39
40impl FFTConvolver {
41 pub fn new(ir_data: &[f64], channels: usize) -> Self {
47 let ir_len_total = ir_data.len();
48 let ir_len_per_ch = ir_len_total / channels;
49
50 let mut fft_size = 1;
52 while fft_size < (ir_len_per_ch * 2) {
53 fft_size <<= 1;
54 }
55
56 let mut planner = FftPlanner::new();
57 let fft = planner.plan_fft_forward(fft_size);
58
59 let fft_forward = planner.plan_fft_forward(fft_size);
61 let fft_inverse = planner.plan_fft_inverse(fft_size);
62
63 let mut ir_ffts = Vec::with_capacity(channels);
64 let mut overlap_bufs = Vec::with_capacity(channels);
65
66 for ch in 0..channels {
67 let mut buffer = vec![Complex::new(0.0, 0.0); fft_size];
68 for i in 0..ir_len_per_ch {
70 buffer[i] = Complex::new(ir_data[i * channels + ch], 0.0);
71 }
72 fft.process(&mut buffer);
73 ir_ffts.push(buffer);
74 overlap_bufs.push(vec![0.0; ir_len_per_ch - 1]);
75 }
76
77 let scratch_complex = vec![Complex::new(0.0, 0.0); fft_size];
79
80 FFTConvolver {
81 fft_size,
82 impulse_response_fft: ir_ffts,
83 overlap_buffers: overlap_bufs,
84 channels,
85 ir_len: ir_len_per_ch,
86 fft_forward,
87 fft_inverse,
88 scratch_complex,
89 }
90 }
91
    /// Returns the impulse-response length in frames (samples per channel).
    pub fn ir_length(&self) -> usize {
        self.ir_len
    }
96
    /// Returns the FFT length used for block processing.
    pub fn fft_size(&self) -> usize {
        self.fft_size
    }
101
102 pub fn reset(&mut self) {
105 for overlap in &mut self.overlap_buffers {
106 overlap.fill(0.0);
107 }
108 }
109
110 #[inline]
111 #[allow(clippy::too_many_arguments)]
112 fn prepare_channel_chunk(
113 scratch: &mut [Complex<f64>],
114 overlap: &[f64],
115 input: &[f64],
116 channels: usize,
117 channel: usize,
118 processed_frames: usize,
119 chunk_len: usize,
120 ir_len: usize,
121 ) {
122 for i in 0..ir_len - 1 {
123 scratch[i] = Complex::new(overlap[i], 0.0);
124 }
125
126 for i in 0..chunk_len {
127 scratch[i + ir_len - 1] =
128 Complex::new(input[(processed_frames + i) * channels + channel], 0.0);
129 }
130 scratch[ir_len - 1 + chunk_len..].fill(Complex::new(0.0, 0.0));
131 }
132
133 #[inline]
134 fn update_channel_overlap(
135 overlap: &mut [f64],
136 input: &[f64],
137 channels: usize,
138 channel: usize,
139 processed_frames: usize,
140 chunk_len: usize,
141 ir_len: usize,
142 ) {
143 if chunk_len >= ir_len - 1 {
144 for i in 0..ir_len - 1 {
145 overlap[i] =
146 input[(processed_frames + chunk_len - (ir_len - 1) + i) * channels + channel];
147 }
148 } else {
149 let shift = chunk_len;
150 let keep = ir_len - 1 - shift;
151 overlap.copy_within(shift..shift + keep, 0);
152 for i in 0..shift {
153 overlap[keep + i] = input[(processed_frames + i) * channels + channel];
154 }
155 }
156 }
157
158 #[inline]
159 #[allow(clippy::too_many_arguments)]
160 fn write_channel_output(
161 scratch: &[Complex<f64>],
162 output: &mut [f64],
163 channels: usize,
164 channel: usize,
165 processed_frames: usize,
166 chunk_len: usize,
167 ir_len: usize,
168 inv_n: f64,
169 ) {
170 for i in 0..chunk_len {
171 output[(processed_frames + i) * channels + channel] =
172 scratch[i + ir_len - 1].re * inv_n;
173 }
174 }
175
176 #[inline]
177 fn process_channel_chunk_fft(&mut self, channel: usize) {
178 self.fft_forward.process(&mut self.scratch_complex);
179
180 let ir_fft = &self.impulse_response_fft[channel];
181 multiply_spectrum_in_place(&mut self.scratch_complex, ir_fft);
182
183 self.fft_inverse.process(&mut self.scratch_complex);
184 }
185
186 #[inline]
195 pub fn process_into(&mut self, input: &[f64], output: &mut [f64]) {
196 debug_assert_eq!(input.len(), output.len());
197
198 let channels = self.channels;
199 let total_frames = input.len() / channels;
200 let fft_size = self.fft_size;
201 let ir_len = self.ir_len;
202 let step_size = fft_size - ir_len + 1;
203 let inv_n = 1.0 / fft_size as f64;
204
205 output[total_frames * channels..].fill(0.0);
208
209 for ch in 0..channels {
210 let mut processed_frames = 0;
211
212 while processed_frames < total_frames {
213 let chunk_len = std::cmp::min(step_size, total_frames - processed_frames);
214
215 Self::prepare_channel_chunk(
216 &mut self.scratch_complex,
217 &self.overlap_buffers[ch],
218 input,
219 channels,
220 ch,
221 processed_frames,
222 chunk_len,
223 ir_len,
224 );
225 self.process_channel_chunk_fft(ch);
226 Self::write_channel_output(
227 &self.scratch_complex,
228 output,
229 channels,
230 ch,
231 processed_frames,
232 chunk_len,
233 ir_len,
234 inv_n,
235 );
236
237 Self::update_channel_overlap(
238 &mut self.overlap_buffers[ch],
239 input,
240 channels,
241 ch,
242 processed_frames,
243 chunk_len,
244 ir_len,
245 );
246
247 processed_frames += chunk_len;
248 }
249 }
250 }
251
    /// Convolves interleaved `input` and returns the result in a newly
    /// allocated vector of the same length.
    ///
    /// Allocates the output buffer and delegates to `process_into`.
    pub fn process(&mut self, input: &[f64]) -> Vec<f64> {
        let mut output = vec![0.0; input.len()];
        self.process_into(input, &mut output);
        output
    }
260
261 #[inline]
269 pub fn process_inplace(&mut self, buf: &mut [f64]) {
270 let channels = self.channels;
275 let total_frames = buf.len() / channels;
276 let fft_size = self.fft_size;
277 let ir_len = self.ir_len;
278 let step_size = fft_size - ir_len + 1;
279 let inv_n = 1.0 / fft_size as f64;
280
281 for ch in 0..channels {
286 let mut processed_frames = 0;
287
288 while processed_frames < total_frames {
289 let chunk_len = std::cmp::min(step_size, total_frames - processed_frames);
290
291 Self::prepare_channel_chunk(
292 &mut self.scratch_complex,
293 &self.overlap_buffers[ch],
294 buf,
295 channels,
296 ch,
297 processed_frames,
298 chunk_len,
299 ir_len,
300 );
301 self.process_channel_chunk_fft(ch);
302
303 Self::update_channel_overlap(
307 &mut self.overlap_buffers[ch],
308 buf,
309 channels,
310 ch,
311 processed_frames,
312 chunk_len,
313 ir_len,
314 );
315
316 Self::write_channel_output(
318 &self.scratch_complex,
319 buf,
320 channels,
321 ch,
322 processed_frames,
323 chunk_len,
324 ir_len,
325 inv_n,
326 );
327
328 processed_frames += chunk_len;
329 }
330 }
331 }
332}
333
334#[inline]
335fn multiply_spectrum_in_place(samples: &mut [Complex<f64>], ir_fft: &[Complex<f64>]) {
336 for (sample, ir) in samples.iter_mut().zip(ir_fft) {
337 let re = sample.re * ir.re - sample.im * ir.im;
338 let im = sample.re * ir.im + sample.im * ir.re;
339 sample.re = re;
340 sample.im = im;
341 }
342}
343
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance used when comparing convolved samples.
    const TOLERANCE: f64 = 1e-10;

    /// A mono unit-impulse IR must reproduce the input.
    #[test]
    fn test_convolver_identity() {
        let ir = vec![1.0, 0.0, 0.0, 0.0];
        let mut conv = FFTConvolver::new(&ir, 1);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = vec![0.0; input.len()];

        conv.process_into(&input, &mut output);

        for (i, (out, expected)) in output.iter().zip(&input).enumerate() {
            assert!(
                (out - expected).abs() < TOLERANCE,
                "Mismatch at {}: {} vs {}",
                i,
                out,
                expected
            );
        }
    }

    /// The interleaved IR `[1, 1, 0, 0]` is a unit impulse on both channels,
    /// so the stereo output must equal the stereo input sample for sample.
    #[test]
    fn test_convolver_stereo() {
        let ir = vec![1.0, 1.0, 0.0, 0.0];
        let mut conv = FFTConvolver::new(&ir, 2);

        let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut output = vec![0.0; input.len()];

        conv.process_into(&input, &mut output);

        assert!(output.iter().any(|&x| x != 0.0));
        for (i, (out, expected)) in output.iter().zip(&input).enumerate() {
            assert!(
                (out - expected).abs() < TOLERANCE,
                "Stereo identity mismatch at {}: {} vs {}",
                i,
                out,
                expected
            );
        }
    }

    /// Smoke test for repeated use of one convolver with caller-owned
    /// buffers. It does not measure allocations; it only checks that many
    /// consecutive calls keep producing finite, non-trivial output.
    #[test]
    fn test_zero_allocation() {
        let ir: Vec<f64> = (0..1024).map(|i| (i as f64 / 1024.0).sin()).collect();
        let mut conv = FFTConvolver::new(&ir, 1);

        let input = vec![0.5; 4096];
        let mut output = vec![0.0; 4096];

        for _ in 0..100 {
            conv.process_into(&input, &mut output);
        }

        assert!(output.iter().any(|&x| x != 0.0));
        assert!(output.iter().all(|x| x.is_finite()));
    }

    /// In-place processing with a unit-impulse IR must leave the buffer as is.
    #[test]
    fn test_inplace_identity() {
        let ir = vec![1.0, 0.0, 0.0, 0.0];
        let mut conv = FFTConvolver::new(&ir, 1);

        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut buf = original.clone();

        conv.process_inplace(&mut buf);

        for (i, (got, expected)) in buf.iter().zip(&original).enumerate() {
            assert!(
                (got - expected).abs() < TOLERANCE,
                "Inplace identity mismatch at {}: {} vs {}",
                i,
                got,
                expected
            );
        }
    }

    /// `process_inplace` and `process_into` must agree on a multi-chunk input.
    #[test]
    fn test_inplace_matches_process_into() {
        let ir: Vec<f64> = (0..32).map(|i| (i as f64 / 32.0).sin() * 0.1).collect();
        let input: Vec<f64> = (0..256).map(|i| (i as f64 * 0.05).sin()).collect();

        let mut conv1 = FFTConvolver::new(&ir, 1);
        let mut conv2 = FFTConvolver::new(&ir, 1);

        let mut output_into = vec![0.0; input.len()];
        conv1.process_into(&input, &mut output_into);

        let mut buf_inplace = input.clone();
        conv2.process_inplace(&mut buf_inplace);

        for (i, (into, inplace)) in output_into.iter().zip(&buf_inplace).enumerate() {
            assert!(
                (into - inplace).abs() < TOLERANCE,
                "Mismatch at {}: into={} vs inplace={}",
                i,
                into,
                inplace
            );
        }
    }

    /// Runs the same deterministic signal through `process`, `process_into`
    /// and `process_inplace` on fresh convolvers and checks that all three
    /// produce the same samples.
    fn assert_processing_paths_equivalent(channels: usize, ir_frames: usize, input_frames: usize) {
        let ir: Vec<f64> = (0..ir_frames * channels)
            .map(|i| ((i + 1) as f64 * 0.17).sin() * 0.05)
            .collect();
        let input: Vec<f64> = (0..input_frames * channels)
            .map(|i| ((i + 3) as f64 * 0.11).cos() * 0.5)
            .collect();

        let mut process_conv = FFTConvolver::new(&ir, channels);
        let mut into_conv = FFTConvolver::new(&ir, channels);
        let mut inplace_conv = FFTConvolver::new(&ir, channels);

        let process_output = process_conv.process(&input);

        // NaN pre-fill makes any sample `process_into` fails to write show up.
        let mut into_output = vec![f64::NAN; input.len()];
        into_conv.process_into(&input, &mut into_output);

        let mut inplace_output = input.clone();
        inplace_conv.process_inplace(&mut inplace_output);

        for i in 0..input.len() {
            assert!(
                (process_output[i] - into_output[i]).abs() < TOLERANCE,
                "process/process_into mismatch at {i}: {} vs {}",
                process_output[i],
                into_output[i]
            );
            assert!(
                (process_output[i] - inplace_output[i]).abs() < TOLERANCE,
                "process/process_inplace mismatch at {i}: {} vs {}",
                process_output[i],
                inplace_output[i]
            );
        }
    }

    /// Covers inputs shorter than, equal to and longer than the IR, with
    /// one, two and six channels.
    #[test]
    fn test_processing_paths_equivalent_for_boundary_chunk_sizes() {
        assert_processing_paths_equivalent(1, 8, 4);
        assert_processing_paths_equivalent(2, 8, 8);
        assert_processing_paths_equivalent(6, 8, 20);
    }

    /// A unit impulse shorter than the IR must come out as the leading IR taps.
    #[test]
    fn test_inplace_small_buffer() {
        let ir = vec![1.0, 0.5, 0.25, 0.125, 0.0, 0.0, 0.0, 0.0];
        let mut conv = FFTConvolver::new(&ir, 1);

        let mut buf = vec![1.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf);

        let expected = [1.0, 0.5, 0.25, 0.125];
        for (got, want) in buf.iter().zip(&expected) {
            assert!(
                (got - want).abs() < TOLERANCE,
                "Expected {:?}, got {}",
                want,
                got
            );
        }
    }

    /// In-place stereo processing with per-channel unit impulses is identity.
    #[test]
    fn test_inplace_stereo_identity() {
        let ir = vec![1.0, 1.0, 0.0, 0.0];
        let mut conv = FFTConvolver::new(&ir, 2);

        let original = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
        let mut buf = original.clone();

        conv.process_inplace(&mut buf);

        for (i, (got, expected)) in buf.iter().zip(&original).enumerate() {
            assert!(
                (got - expected).abs() < TOLERANCE,
                "Stereo inplace identity mismatch at {}: {} vs {}",
                i,
                got,
                expected
            );
        }
    }

    /// State must carry over between consecutive `process_inplace` calls.
    #[test]
    fn test_inplace_multi_chunk() {
        let ir = vec![1.0, 0.5, 0.0, 0.0];
        let mut conv = FFTConvolver::new(&ir, 1);

        let mut buf1 = vec![1.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf1);

        let mut buf2 = vec![0.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf2);

        assert!((buf1[0] - 1.0).abs() < TOLERANCE);
        assert!((buf1[1] - 0.5).abs() < TOLERANCE);
        assert!(buf1[2].abs() < TOLERANCE);
        assert!(buf1[3].abs() < TOLERANCE);
        // The response to the first impulse has fully decayed by the second call.
        assert!(buf2.iter().all(|x| x.abs() < TOLERANCE));

        // An impulse on the last sample of one call must leak its second tap
        // into the first sample of the next call.
        let mut buf3 = vec![0.0, 0.0, 0.0, 1.0];
        conv.process_inplace(&mut buf3);
        let mut buf4 = vec![0.0, 0.0, 0.0, 0.0];
        conv.process_inplace(&mut buf4);

        let expected3 = [0.0, 0.0, 0.0, 1.0];
        let expected4 = [0.5, 0.0, 0.0, 0.0];
        for (got, want) in buf3.iter().zip(&expected3) {
            assert!((got - want).abs() < TOLERANCE, "buf3: {} vs {}", got, want);
        }
        for (got, want) in buf4.iter().zip(&expected4) {
            assert!((got - want).abs() < TOLERANCE, "buf4: {} vs {}", got, want);
        }
    }
}