1#![doc = include_str!("../README.md")]
2
3mod fft;
4mod utilities;
5use crate::fft::Fft;
6use crate::utilities::{
7 complex_multiply_accumulate, complex_size, copy_and_pad, next_power_of_2, sum,
8};
9use realfft::num_complex::Complex;
10use realfft::num_traits::Zero;
11use realfft::{FftError, FftNum};
12use rtsan_standalone::nonblocking;
13use thiserror::Error;
14
15#[derive(Error, Debug)]
16pub enum FFTConvolverError {
17 #[error("block size is not allowed to be zero")]
18 BlockSizeZero,
19 #[error("impulse response exceeds configured capacity")]
20 ImpulseResponseExceedsCapacity,
21 #[error("error in fft: {0}")]
22 Fft(#[from] FftError),
23}
24
25#[derive(Clone)]
41pub struct FFTConvolver<F: FftNum> {
42 ir_len: usize,
43 block_size: usize,
44 seg_size: usize,
45 seg_count: usize,
46 active_seg_count: usize,
47 fft_complex_size: usize,
48 segments: Vec<Vec<Complex<F>>>,
49 segments_ir: Vec<Vec<Complex<F>>>,
50 fft_buffer: Vec<F>,
51 fft: Fft<F>,
52 pre_multiplied: Vec<Complex<F>>,
53 conv: Vec<Complex<F>>,
54 overlap: Vec<F>,
55 current: usize,
56 input_buffer: Vec<F>,
57 input_buffer_fill: usize,
58}
59
60impl<F: FftNum> Default for FFTConvolver<F> {
61 fn default() -> Self {
62 Self {
63 ir_len: Default::default(),
64 block_size: Default::default(),
65 seg_size: Default::default(),
66 seg_count: Default::default(),
67 active_seg_count: Default::default(),
68 fft_complex_size: Default::default(),
69 segments: Default::default(),
70 segments_ir: Default::default(),
71 fft_buffer: Default::default(),
72 fft: Default::default(),
73 pre_multiplied: Default::default(),
74 conv: Default::default(),
75 overlap: Default::default(),
76 current: Default::default(),
77 input_buffer: Default::default(),
78 input_buffer_fill: Default::default(),
79 }
80 }
81}
82
83impl<F: FftNum> FFTConvolver<F> {
84 pub fn init(
113 &mut self,
114 block_size: usize,
115 impulse_response: &[F],
116 ) -> Result<(), FFTConvolverError> {
117 if block_size == 0 {
118 return Err(FFTConvolverError::BlockSizeZero);
119 }
120
121 self.ir_len = impulse_response.len();
122
123 if self.ir_len == 0 {
124 return Ok(());
125 }
126
127 self.block_size = next_power_of_2(block_size);
128 self.seg_size = 2 * self.block_size;
129 self.seg_count = (self.ir_len as f64 / self.block_size as f64).ceil() as usize;
130 self.active_seg_count = self.seg_count;
131 self.fft_complex_size = complex_size(self.seg_size);
132
133 self.fft.init(self.seg_size);
135 self.fft_buffer = vec![F::zero(); self.seg_size];
136
137 self.segments = vec![vec![Complex::zero(); self.fft_complex_size]; self.seg_count];
139
140 self.segments_ir = vec![vec![Complex::zero(); self.fft_complex_size]; self.seg_count];
142 for (i, segment) in self.segments_ir.iter_mut().enumerate() {
143 let remaining = self.ir_len - (i * self.block_size);
144 let size_copy = if remaining >= self.block_size {
145 self.block_size
146 } else {
147 remaining
148 };
149 copy_and_pad(
150 &mut self.fft_buffer,
151 &impulse_response[i * self.block_size..],
152 size_copy,
153 );
154 self.fft.forward(&mut self.fft_buffer, segment)?;
155 }
156
157 self.pre_multiplied = vec![Complex::zero(); self.fft_complex_size];
159 self.conv = vec![Complex::zero(); self.fft_complex_size];
160 self.overlap.resize(self.block_size, F::zero());
161
162 self.input_buffer = vec![F::zero(); self.block_size];
164 self.input_buffer_fill = 0;
165
166 self.current = 0;
168
169 Ok(())
170 }
171
172 #[nonblocking]
201 pub fn set_response(&mut self, impulse_response: &[F]) -> Result<(), FFTConvolverError> {
202 if impulse_response.len() > self.ir_len {
203 return Err(FFTConvolverError::ImpulseResponseExceedsCapacity);
204 }
205
206 self.fft_buffer.fill(F::zero());
207 self.conv.fill(Complex::zero());
208 self.pre_multiplied.fill(Complex::zero());
209 self.overlap.fill(F::zero());
210
211 self.active_seg_count =
212 (impulse_response.len() as f64 / self.block_size as f64).ceil() as usize;
213
214 for (i, segment) in self
216 .segments_ir
217 .iter_mut()
218 .enumerate()
219 .take(self.active_seg_count)
220 {
221 let remaining = impulse_response.len() - (i * self.block_size);
222 let size_copy = if remaining >= self.block_size {
223 self.block_size
224 } else {
225 remaining
226 };
227 copy_and_pad(
228 &mut self.fft_buffer,
229 &impulse_response[i * self.block_size..],
230 size_copy,
231 );
232 self.fft.forward(&mut self.fft_buffer, segment)?;
233 }
234
235 for segment in self.segments_ir.iter_mut().skip(self.active_seg_count) {
237 segment.fill(Complex::zero());
238 }
239
240 self.input_buffer.fill(F::zero());
241 self.input_buffer_fill = 0;
242 self.current = 0;
243
244 Ok(())
245 }
246
247 #[nonblocking]
279 pub fn process(&mut self, input: &[F], output: &mut [F]) -> Result<(), FFTConvolverError> {
280 if self.active_seg_count == 0 {
281 output.fill(F::zero());
282 return Ok(());
283 }
284
285 let mut processed = 0;
286 while processed < output.len() {
287 let input_buffer_was_empty = self.input_buffer_fill == 0;
288 let processing = std::cmp::min(
289 output.len() - processed,
290 self.block_size - self.input_buffer_fill,
291 );
292
293 let input_buffer_pos = self.input_buffer_fill;
294 self.input_buffer[input_buffer_pos..input_buffer_pos + processing]
295 .copy_from_slice(&input[processed..processed + processing]);
296
297 copy_and_pad(&mut self.fft_buffer, &self.input_buffer, self.block_size);
299 if let Err(err) = self
300 .fft
301 .forward(&mut self.fft_buffer, &mut self.segments[self.current])
302 {
303 output.fill(F::zero());
304 return Err(err.into());
305 }
306
307 if input_buffer_was_empty {
309 self.pre_multiplied.fill(Complex::zero());
310 for i in 1..self.active_seg_count {
311 let index_ir = i;
312 let index_audio = (self.current + i) % self.active_seg_count;
313 complex_multiply_accumulate(
314 &mut self.pre_multiplied,
315 &self.segments_ir[index_ir],
316 &self.segments[index_audio],
317 );
318 }
319 }
320 self.conv.copy_from_slice(&self.pre_multiplied);
321 complex_multiply_accumulate(
322 &mut self.conv,
323 &self.segments[self.current],
324 &self.segments_ir[0],
325 );
326
327 if let Err(err) = self.fft.inverse(&mut self.conv, &mut self.fft_buffer) {
329 output.fill(F::zero());
330 return Err(err.into());
331 }
332
333 sum(
335 &mut output[processed..processed + processing],
336 &self.fft_buffer[input_buffer_pos..input_buffer_pos + processing],
337 &self.overlap[input_buffer_pos..input_buffer_pos + processing],
338 );
339
340 self.input_buffer_fill += processing;
342 if self.input_buffer_fill == self.block_size {
343 self.input_buffer.fill(F::zero());
345 self.input_buffer_fill = 0;
346 self.overlap
348 .copy_from_slice(&self.fft_buffer[self.block_size..self.block_size * 2]);
349
350 self.current = if self.current > 0 {
352 self.current - 1
353 } else {
354 self.active_seg_count - 1
355 };
356 }
357 processed += processing;
358 }
359 Ok(())
360 }
361
362 #[nonblocking]
397 pub fn reset(&mut self) {
398 self.input_buffer.fill(F::zero());
399 self.input_buffer_fill = 0;
400
401 self.fft_buffer.fill(F::zero());
402 for segment in &mut self.segments {
403 segment.fill(Complex::zero());
404 }
405
406 self.conv.fill(Complex::zero());
407 self.pre_multiplied.fill(Complex::zero());
408
409 self.overlap.fill(F::zero());
410 self.current = 0;
411 }
412}
413
414#[cfg(test)]
416mod tests {
417 use crate::{FFTConvolver, FFTConvolverError};
418
419 #[test]
420 fn init_test() {
421 let mut convolver = FFTConvolver::default();
422 let ir = vec![1., 0., 0., 0.];
423 convolver.init(10, &ir).unwrap();
424
425 assert_eq!(convolver.ir_len, 4);
426 assert_eq!(convolver.block_size, 16);
427 assert_eq!(convolver.seg_size, 32);
428 assert_eq!(convolver.seg_count, 1);
429 assert_eq!(convolver.active_seg_count, 1);
430 assert_eq!(convolver.fft_complex_size, 17);
431
432 assert_eq!(convolver.segments.len(), 1);
433 assert_eq!(convolver.segments.first().unwrap().len(), 17);
434 for seg in &convolver.segments {
435 for num in seg {
436 assert_eq!(num.re, 0.);
437 assert_eq!(num.im, 0.);
438 }
439 }
440
441 assert_eq!(convolver.segments_ir.len(), 1);
442 assert_eq!(convolver.segments_ir.first().unwrap().len(), 17);
443 for seg in &convolver.segments_ir {
444 for num in seg {
445 assert_eq!(num.re, 1.);
446 assert_eq!(num.im, 0.);
447 }
448 }
449
450 assert_eq!(convolver.fft_buffer.len(), 32);
451 assert_eq!(*convolver.fft_buffer.first().unwrap(), 1.);
452 for i in 1..convolver.fft_buffer.len() {
453 assert_eq!(convolver.fft_buffer[i], 0.);
454 }
455
456 assert_eq!(convolver.pre_multiplied.len(), 17);
457 for num in &convolver.pre_multiplied {
458 assert_eq!(num.re, 0.);
459 assert_eq!(num.im, 0.);
460 }
461
462 assert_eq!(convolver.conv.len(), 17);
463 for num in &convolver.conv {
464 assert_eq!(num.re, 0.);
465 assert_eq!(num.im, 0.);
466 }
467
468 assert_eq!(convolver.overlap.len(), 16);
469 for num in &convolver.overlap {
470 assert_eq!(*num, 0.);
471 }
472
473 assert_eq!(convolver.input_buffer.len(), 16);
474 for num in &convolver.input_buffer {
475 assert_eq!(*num, 0.);
476 }
477
478 assert_eq!(convolver.input_buffer_fill, 0);
479 }
480
481 #[test]
482 fn process_test() {
483 let mut convolver = FFTConvolver::<f32>::default();
484 let ir = vec![1., 0., 0., 0.];
485 convolver.init(2, &ir).unwrap();
486
487 let input = vec![0., 1., 2., 3.];
488 let mut output = vec![0.; 4];
489 convolver.process(&input, &mut output).unwrap();
490
491 for i in 0..output.len() {
492 assert_eq!(input[i], output[i]);
493 }
494 }
495
496 #[test]
497 fn reset_test() {
498 let ir = vec![0.5, 0.3, 0.2, 0.1];
500 let block_size = 4;
501
502 let mut convolver1 = FFTConvolver::<f32>::default();
504 convolver1.init(block_size, &ir).unwrap();
505
506 let history_input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
508 let mut history_output = vec![0.0; 8];
509 convolver1
510 .process(&history_input, &mut history_output)
511 .unwrap();
512
513 convolver1.reset();
515
516 let test_input = vec![1.0, 1.0, 1.0, 1.0];
518 let mut output1 = vec![0.0; 4];
519 convolver1.process(&test_input, &mut output1).unwrap();
520
521 let mut convolver2 = FFTConvolver::<f32>::default();
523 convolver2.init(block_size, &ir).unwrap();
524 let mut output2 = vec![0.0; 4];
525 convolver2.process(&test_input, &mut output2).unwrap();
526
527 for i in 0..output1.len() {
529 assert!(
530 (output1[i] - output2[i]).abs() < 1e-5,
531 "Mismatch at index {}: cleared convolver produced {}, fresh convolver produced {}",
532 i,
533 output1[i],
534 output2[i]
535 );
536 }
537 }
538
539 #[test]
540 fn reset_preserves_configuration() {
541 let ir = vec![0.5, 0.3, 0.2, 0.1];
543 let block_size = 4;
544
545 let mut convolver = FFTConvolver::<f32>::default();
546 convolver.init(block_size, &ir).unwrap();
547
548 let ir_len = convolver.ir_len;
549 let block_size_actual = convolver.block_size;
550 let seg_size = convolver.seg_size;
551 let seg_count = convolver.seg_count;
552
553 let input = vec![1.0, 2.0, 3.0, 4.0];
555 let mut output = vec![0.0; 4];
556 convolver.process(&input, &mut output).unwrap();
557
558 convolver.reset();
560
561 assert_eq!(convolver.ir_len, ir_len);
563 assert_eq!(convolver.block_size, block_size_actual);
564 assert_eq!(convolver.seg_size, seg_size);
565 assert_eq!(convolver.seg_count, seg_count);
566 }
567
568 #[test]
569 fn set_response_equals_init() {
570 let ir1 = vec![0.5, 0.3, 0.2, 0.1];
572 let ir2 = vec![0.8, 0.6, 0.4, 0.2];
573 let block_size = 4;
574
575 let mut convolver1 = FFTConvolver::<f32>::default();
577 convolver1.init(block_size, &ir1).unwrap();
578 convolver1.set_response(&ir2).unwrap();
579
580 let mut convolver2 = FFTConvolver::<f32>::default();
582 convolver2.init(block_size, &ir2).unwrap();
583
584 let input = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
586 let mut output1 = vec![0.0; 8];
587 let mut output2 = vec![0.0; 8];
588
589 convolver1.process(&input, &mut output1).unwrap();
590 convolver2.process(&input, &mut output2).unwrap();
591
592 for i in 0..output1.len() {
594 assert!(
595 (output1[i] - output2[i]).abs() < 1e-5,
596 "Mismatch at index {}: set_response produced {}, init produced {}",
597 i,
598 output1[i],
599 output2[i]
600 );
601 }
602 }
603
604 #[test]
605 fn set_response_with_shorter_ir() {
606 let ir1 = vec![0.5, 0.3, 0.2, 0.1, 0.05, 0.02];
608 let ir2 = vec![0.8, 0.6, 0.4];
609 let block_size = 4;
610
611 let mut convolver1 = FFTConvolver::<f32>::default();
613 convolver1.init(block_size, &ir1).unwrap();
614 convolver1.set_response(&ir2).unwrap();
615
616 let mut convolver2 = FFTConvolver::<f32>::default();
618 convolver2.init(block_size, &ir2).unwrap();
619
620 let input = vec![1.0, 1.0, 1.0, 1.0];
622 let mut output1 = vec![0.0; 4];
623 let mut output2 = vec![0.0; 4];
624
625 convolver1.process(&input, &mut output1).unwrap();
626 convolver2.process(&input, &mut output2).unwrap();
627
628 for i in 0..output1.len() {
630 assert!(
631 (output1[i] - output2[i]).abs() < 1e-5,
632 "Mismatch at index {}: set_response produced {}, init produced {}",
633 i,
634 output1[i],
635 output2[i]
636 );
637 }
638 }
639
640 #[test]
641 fn set_response_too_long_returns_error() {
642 let ir1 = vec![0.5, 0.3, 0.2, 0.1];
644 let ir2 = vec![0.8, 0.6, 0.4, 0.2, 0.1, 0.05];
645 let block_size = 4;
646
647 let mut convolver = FFTConvolver::<f32>::default();
648 convolver.init(block_size, &ir1).unwrap();
649
650 let result = convolver.set_response(&ir2);
652 assert!(result.is_err());
653 assert!(matches!(
654 result.unwrap_err(),
655 FFTConvolverError::ImpulseResponseExceedsCapacity
656 ));
657 }
658
659 #[test]
660 fn test_zero_latency() {
661 let mut convolver = FFTConvolver::<f32>::default();
664
665 let ir = vec![0.5, 0.3, 0.2, 0.1];
667 convolver.init(4, &ir).unwrap();
668
669 let mut input = vec![0.0; 16];
671 input[0] = 1.0; let mut output = vec![0.0; 16];
674 convolver.process(&input, &mut output).unwrap();
675
676 assert!(
679 output[0].abs() > 0.0,
680 "Output[0] should be non-zero, indicating zero latency. Got: {}",
681 output[0]
682 );
683
684 assert!(
686 (output[0] - 0.5).abs() < 1e-5,
687 "output[0] should be 0.5, got {}",
688 output[0]
689 );
690 assert!(
691 (output[1] - 0.3).abs() < 1e-5,
692 "output[1] should be 0.3, got {}",
693 output[1]
694 );
695 assert!(
696 (output[2] - 0.2).abs() < 1e-5,
697 "output[2] should be 0.2, got {}",
698 output[2]
699 );
700 assert!(
701 (output[3] - 0.1).abs() < 1e-5,
702 "output[3] should be 0.1, got {}",
703 output[3]
704 );
705 }
706}