mod fft;
mod utilities;
use crate::fft::Fft;
use crate::utilities::{
complex_multiply_accumulate, complex_size, copy_and_pad, next_power_of_2, sum,
};
use realfft::FftError;
use rustfft::num_complex::Complex;
use thiserror::Error;
/// Errors returned by `FFTConvolver::init`.
#[derive(Error, Debug)]
pub enum FFTConvolverInitError {
/// The `block_size` argument passed to `init` was `0`.
#[error("block size is not allowed to be zero")]
BlockSizeZero(),
/// The forward FFT of an impulse-response segment failed; wraps the
/// underlying `realfft` error.
#[error("fft error")]
Fft(#[from] FftError),
}
/// Errors returned by `FFTConvolver::process`.
#[derive(Error, Debug)]
pub enum FFTConvolverProcessError {
/// The forward FFT of the input block or the inverse FFT of the convolved
/// spectrum failed; wraps the underlying `realfft` error.
#[error("fft error")]
Fft(#[from] FftError),
}
/// Convolves a stream of `f32` samples with a fixed impulse response, working
/// on FFT spectra of block-sized segments.
///
/// A value starts out empty (see `default`), is configured by `init`, and then
/// turns input samples into output samples through `process`.
#[derive(Debug)]
pub struct FFTConvolver {
/// Length of the impulse response handed to `init`.
ir_len: usize,
/// Internal block length: the requested block size passed through
/// `next_power_of_2`.
block_size: usize,
/// FFT length in samples; `init` sets it to `2 * block_size`.
seg_size: usize,
/// Number of `block_size`-long segments the impulse response is split into
/// (rounded up).
seg_count: usize,
/// Number of segments `process` works with. `0` means "not initialized" and
/// makes `process` write zeros. `init` sets it equal to `seg_count`.
active_seg_count: usize,
/// Length of one spectrum, as returned by `complex_size(seg_size)`.
fft_complex_size: usize,
/// Spectra of input blocks; `process` writes the forward FFT of the input
/// buffer into `segments[current]`.
segments: Vec<Vec<Complex<f32>>>,
/// Spectra of the impulse-response segments, one per segment, computed once
/// in `init`.
segments_ir: Vec<Vec<Complex<f32>>>,
/// Real-valued scratch buffer of `seg_size` samples, used as forward-FFT
/// input and inverse-FFT output.
fft_buffer: Vec<f32>,
/// FFT engine (project type `Fft`), initialized for `seg_size` by `init`.
fft: Fft,
/// Result of `complex_multiply_accumulate` over impulse-response segments
/// `1..` and their matching input spectra; rebuilt whenever a new input
/// block starts.
pre_multiplied: Vec<Complex<f32>>,
/// Spectrum handed to the inverse FFT: a copy of `pre_multiplied` plus the
/// contribution of `segments[current]` and `segments_ir[0]`.
conv: Vec<Complex<f32>>,
/// Second half of the inverse-FFT output of the last completed block; it is
/// combined with the output of the following block.
overlap: Vec<f32>,
/// Index into `segments` of the slot receiving the current input block;
/// stepped backwards (with wrap-around) after every completed block.
current: usize,
/// Holds up to `block_size` input samples collected for the current block.
input_buffer: Vec<f32>,
/// Number of valid samples currently stored in `input_buffer`.
input_buffer_fill: usize,
}
impl Default for FFTConvolver {
    /// Creates an empty, uninitialized convolver.
    ///
    /// All sizes are zero and all buffers are empty. Until [`FFTConvolver::init`]
    /// succeeds with a non-empty impulse response, [`FFTConvolver::process`]
    /// only writes zeros to its output.
    fn default() -> Self {
        Self {
            ir_len: 0,
            block_size: 0,
            seg_size: 0,
            seg_count: 0,
            active_seg_count: 0,
            fft_complex_size: 0,
            segments: Vec::new(),
            segments_ir: Vec::new(),
            fft_buffer: Vec::new(),
            fft: Fft::default(),
            pre_multiplied: Vec::new(),
            conv: Vec::new(),
            overlap: Vec::new(),
            current: 0,
            input_buffer: Vec::new(),
            input_buffer_fill: 0,
        }
    }
}

impl FFTConvolver {
    /// Discards the impulse response and all internal buffers, returning the
    /// convolver to the state produced by `FFTConvolver::default()`.
    pub fn reset(&mut self) {
        *self = Self::default();
    }

    /// Configures the convolver for `impulse_response`.
    ///
    /// Any previous configuration is discarded first. The requested
    /// `block_size` is passed through `next_power_of_2` to obtain the internal
    /// block length; the impulse response is split into segments of that
    /// length and the spectrum of each (zero-padded) segment is stored.
    ///
    /// An empty `impulse_response` is accepted: the call succeeds and the
    /// convolver stays uninitialized, so `process` produces silence.
    ///
    /// # Errors
    ///
    /// * [`FFTConvolverInitError::BlockSizeZero`] if `block_size` is `0`.
    /// * [`FFTConvolverInitError::Fft`] if transforming a segment fails.
    ///
    /// After an error the convolver is left in its reset state, so a later
    /// call to `process` is safe and produces silence.
    pub fn init(
        &mut self,
        block_size: usize,
        impulse_response: &[f32],
    ) -> Result<(), FFTConvolverInitError> {
        self.reset();
        if block_size == 0 {
            return Err(FFTConvolverInitError::BlockSizeZero());
        }
        self.ir_len = impulse_response.len();
        if impulse_response.is_empty() {
            return Ok(());
        }

        self.block_size = next_power_of_2(block_size);
        self.seg_size = 2 * self.block_size;
        // `chunks(n).len()` is exactly ceil(len / n), without a float round-trip.
        self.seg_count = impulse_response.chunks(self.block_size).len();
        self.fft_complex_size = complex_size(self.seg_size);

        self.fft.init(self.seg_size);
        self.fft_buffer.resize(self.seg_size, 0.);

        let zero = Complex::new(0.0_f32, 0.0);
        self.segments
            .resize(self.seg_count, vec![zero; self.fft_complex_size]);

        // Pre-compute the spectrum of every impulse-response segment. The last
        // chunk may be shorter than a block; `copy_and_pad` zero-fills the rest.
        self.segments_ir.reserve(self.seg_count);
        for chunk in impulse_response.chunks(self.block_size) {
            let mut segment = vec![zero; self.fft_complex_size];
            copy_and_pad(&mut self.fft_buffer, chunk, chunk.len());
            let transformed = self.fft.forward(&mut self.fft_buffer, &mut segment);
            if let Err(err) = transformed {
                // Never leave a half-built convolver behind: `process` would
                // otherwise index past the end of `segments_ir`.
                self.reset();
                return Err(err.into());
            }
            self.segments_ir.push(segment);
        }

        self.pre_multiplied.resize(self.fft_complex_size, zero);
        self.conv.resize(self.fft_complex_size, zero);
        self.overlap.resize(self.block_size, 0.);
        self.input_buffer.resize(self.block_size, 0.);
        self.input_buffer_fill = 0;
        self.current = 0;
        // Activated last, once every buffer it governs is fully set up.
        self.active_seg_count = self.seg_count;
        Ok(())
    }

    /// Convolves `input` with the configured impulse response and writes
    /// `output.len()` samples to `output`.
    ///
    /// Exactly `output.len()` samples are consumed from the front of `input`.
    /// The call may be made with any length; samples are collected internally
    /// until a full block is available, and state carries over between calls.
    ///
    /// If the convolver is uninitialized (never initialized, reset, or
    /// initialized with an empty impulse response), `output` is filled with
    /// zeros and `input` is ignored.
    ///
    /// # Errors
    ///
    /// Returns [`FFTConvolverProcessError::Fft`] if a forward or inverse FFT
    /// fails. In that case `output` is filled with zeros.
    ///
    /// # Panics
    ///
    /// Panics if the convolver is initialized and `input` is shorter than
    /// `output`.
    pub fn process(
        &mut self,
        input: &[f32],
        output: &mut [f32],
    ) -> Result<(), FFTConvolverProcessError> {
        if self.active_seg_count == 0 {
            output.fill(0.);
            return Ok(());
        }
        // Fail before touching any state instead of part-way through the loop.
        assert!(
            input.len() >= output.len(),
            "input ({} samples) must be at least as long as output ({} samples)",
            input.len(),
            output.len()
        );

        let mut processed = 0;
        while processed < output.len() {
            let input_buffer_was_empty = self.input_buffer_fill == 0;
            // Take as much as fits into the current block, but no more than
            // what is still requested.
            let processing =
                (output.len() - processed).min(self.block_size - self.input_buffer_fill);
            let fill_start = self.input_buffer_fill;
            let fill_end = fill_start + processing;
            let io_end = processed + processing;

            self.input_buffer[fill_start..fill_end].copy_from_slice(&input[processed..io_end]);

            // Spectrum of the (possibly still partial, zero-padded) input block.
            copy_and_pad(&mut self.fft_buffer, &self.input_buffer, self.block_size);
            if let Err(err) = self
                .fft
                .forward(&mut self.fft_buffer, &mut self.segments[self.current])
            {
                output.fill(0.);
                return Err(err.into());
            }

            // The contribution of all older input blocks only changes when a
            // new block starts, so it is rebuilt once per block and reused.
            if input_buffer_was_empty {
                self.pre_multiplied.fill(Complex { re: 0., im: 0. });
                for i in 1..self.active_seg_count {
                    let index_audio = (self.current + i) % self.active_seg_count;
                    complex_multiply_accumulate(
                        &mut self.pre_multiplied,
                        &self.segments_ir[i],
                        &self.segments[index_audio],
                    );
                }
            }
            self.conv.copy_from_slice(&self.pre_multiplied);
            complex_multiply_accumulate(
                &mut self.conv,
                &self.segments[self.current],
                &self.segments_ir[0],
            );

            // Back to the time domain.
            if let Err(err) = self.fft.inverse(&mut self.conv, &mut self.fft_buffer) {
                output.fill(0.);
                return Err(err.into());
            }

            // Combine the fresh samples with the tail saved from the previous block.
            sum(
                &mut output[processed..io_end],
                &self.fft_buffer[fill_start..fill_end],
                &self.overlap[fill_start..fill_end],
            );

            self.input_buffer_fill = fill_end;
            if self.input_buffer_fill == self.block_size {
                // Block complete: start collecting a new one.
                self.input_buffer.fill(0.);
                self.input_buffer_fill = 0;
                // Keep the second half of the result for the next block.
                self.overlap
                    .copy_from_slice(&self.fft_buffer[self.block_size..self.block_size * 2]);
                // Step backwards through the ring of input spectra.
                self.current = if self.current > 0 {
                    self.current - 1
                } else {
                    self.active_seg_count - 1
                };
            }
            processed = io_end;
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use crate::FFTConvolver;

    /// Checks every size and buffer set up by `init` for a unit-impulse
    /// response and a block size that is not a power of two.
    #[test]
    fn init_test() {
        let ir = vec![1., 0., 0., 0.];
        let mut convolver = FFTConvolver::default();
        convolver.init(10, &ir).unwrap();

        // Sizes derived from the requested block size and the IR length.
        assert_eq!(convolver.ir_len, 4);
        assert_eq!(convolver.block_size, 16);
        assert_eq!(convolver.seg_size, 32);
        assert_eq!(convolver.seg_count, 1);
        assert_eq!(convolver.active_seg_count, 1);
        assert_eq!(convolver.fft_complex_size, 17);

        // Input spectra start out zeroed.
        assert_eq!(convolver.segments.len(), 1);
        assert_eq!(convolver.segments.first().unwrap().len(), 17);
        for bin in convolver.segments.iter().flatten() {
            assert_eq!(bin.re, 0.);
            assert_eq!(bin.im, 0.);
        }

        // The spectrum of a unit impulse is 1 + 0i in every bin.
        assert_eq!(convolver.segments_ir.len(), 1);
        assert_eq!(convolver.segments_ir.first().unwrap().len(), 17);
        for bin in convolver.segments_ir.iter().flatten() {
            assert_eq!(bin.re, 1.);
            assert_eq!(bin.im, 0.);
        }

        // The scratch buffer still holds the zero-padded impulse response.
        assert_eq!(convolver.fft_buffer.len(), 32);
        assert_eq!(*convolver.fft_buffer.first().unwrap(), 1.);
        for sample in convolver.fft_buffer.iter().skip(1) {
            assert_eq!(*sample, 0.);
        }

        assert_eq!(convolver.pre_multiplied.len(), 17);
        for bin in convolver.pre_multiplied.iter() {
            assert_eq!(bin.re, 0.);
            assert_eq!(bin.im, 0.);
        }

        assert_eq!(convolver.conv.len(), 17);
        for bin in convolver.conv.iter() {
            assert_eq!(bin.re, 0.);
            assert_eq!(bin.im, 0.);
        }

        assert_eq!(convolver.overlap.len(), 16);
        for sample in convolver.overlap.iter() {
            assert_eq!(*sample, 0.);
        }

        assert_eq!(convolver.input_buffer.len(), 16);
        for sample in convolver.input_buffer.iter() {
            assert_eq!(*sample, 0.);
        }

        assert_eq!(convolver.input_buffer_fill, 0);
    }

    /// Convolving with a unit impulse must reproduce the input unchanged.
    #[test]
    fn process_test() {
        let ir = vec![1., 0., 0., 0.];
        let mut convolver = FFTConvolver::default();
        convolver.init(2, &ir).unwrap();

        let input = vec![0., 1., 2., 3.];
        let mut output = vec![0.; 4];
        convolver.process(&input, &mut output).unwrap();

        for (expected, actual) in input.iter().zip(output.iter()) {
            assert_eq!(expected, actual);
        }
    }
}