#[cfg(feature = "std")]
use std::vec::Vec;
#[cfg(all(not(feature = "std"), feature = "alloc"))]
use alloc::vec::Vec;
use num_traits::Float;
use num_complex::Complex;
use crate::fft::SimpleFFT;
use num_traits::FromPrimitive;
/// Largest impulse-response length (in taps) that is convolved purely in the
/// time domain; longer impulse responses switch the convolver to its FFT path.
const DIRECT_CONV_THRESHOLD: usize = 1 << 8;
/// Streaming FIR convolver that filters one sample at a time with a fixed
/// impulse response (IR).
///
/// IRs of at most `DIRECT_CONV_THRESHOLD` taps are evaluated directly in the
/// time domain. Longer IRs are split in two:
/// - The first `block_size` taps (the "head") are still evaluated directly.
/// - The remaining taps (the "tail") are convolved block-wise with an FFT and
///   overlap-add.
///
/// The tail only ever needs input that is at least `block_size` samples old,
/// so its contribution is ready before it is needed. Both paths therefore
/// return `y[n] = Σ ir[k]·x[n-k]` (up to floating-point rounding) with no
/// added latency.
pub struct Convolver<T: Float + FromPrimitive> {
    /// Impulse response taps.
    ir: Vec<T>,
    /// FFT engine for the tail; `Some` only when `use_fft` is set.
    fft: Option<SimpleFFT<T>>,
    /// Transform length used for the tail (power of two; 0 in direct mode).
    fft_size: usize,
    /// Input block currently being collected for the tail (`block_size` long).
    input_buffer: Vec<T>,
    /// Overlap-add accumulator: `output_buffer[p]` holds the tail contribution
    /// for the `p`-th sample of the block currently being collected.
    output_buffer: Vec<T>,
    /// Spectrum of the zero-padded IR tail (the taps after the head).
    ir_fft: Vec<Complex<T>>,
    /// Scratch buffer reused as transform input / inverse-transform output.
    fft_in: Vec<Complex<T>>,
    /// Scratch buffer reused for the block spectrum.
    fft_out: Vec<Complex<T>>,
    /// Write position inside `input_buffer`.
    input_pos: usize,
    /// Ring buffer of past inputs used by the directly evaluated taps.
    history: Vec<T>,
    /// Next write slot in `history`, which is also the oldest stored sample.
    history_pos: usize,
    /// Whether the FFT tail path is active.
    use_fft: bool,
}

impl<T: Float + FromPrimitive> Convolver<T> {
    /// Creates a convolver for the impulse response `ir`.
    ///
    /// An empty `ir` yields a pass-through convolver (see [`Convolver::process`]).
    pub fn new(ir: Vec<T>) -> Self {
        if ir.len() > DIRECT_CONV_THRESHOLD {
            Self::new_fft(ir)
        } else {
            Self::new_direct(ir)
        }
    }

    /// Builds a purely time-domain convolver whose history covers every tap.
    fn new_direct(ir: Vec<T>) -> Self {
        let history = vec![T::zero(); ir.len()];
        Self {
            ir,
            fft: None,
            fft_size: 0,
            input_buffer: Vec::new(),
            output_buffer: Vec::new(),
            ir_fft: Vec::new(),
            fft_in: Vec::new(),
            fft_out: Vec::new(),
            input_pos: 0,
            history,
            history_pos: 0,
            use_fft: false,
        }
    }

    /// Builds a head/tail convolver.
    ///
    /// The first `block_size` taps are applied directly. The remaining taps
    /// are applied through block FFT convolution with overlap-add.
    fn new_fft(ir: Vec<T>) -> Self {
        // The head is exactly one tail block long. `max(1)` keeps the block
        // non-empty even if the threshold constant is ever set to zero.
        let block_size = DIRECT_CONV_THRESHOLD.max(1).min(ir.len());
        let tail = &ir[block_size..];
        // Linear convolution of one block with the tail spans
        // `block_size + tail.len() - 1` samples. The transform must hold all
        // of them, otherwise circular wrap-around would corrupt the output.
        let fft_size =
            Self::next_power_of_two((block_size + tail.len() - 1).max(block_size));
        let zero = Complex::new(T::zero(), T::zero());

        let mut fft = SimpleFFT::new(fft_size);
        let mut fft_in = vec![zero; fft_size];
        for (slot, &tap) in fft_in.iter_mut().zip(tail) {
            *slot = Complex::new(tap, T::zero());
        }
        let mut ir_fft = vec![zero; fft_size];
        fft.fft(&fft_in, &mut ir_fft);

        Self {
            fft: Some(fft),
            fft_size,
            input_buffer: vec![T::zero(); block_size],
            output_buffer: vec![T::zero(); fft_size],
            ir_fft,
            fft_in,
            fft_out: vec![zero; fft_size],
            input_pos: 0,
            history: vec![T::zero(); block_size],
            history_pos: 0,
            use_fft: true,
            ir,
        }
    }

    /// Filters one input sample and returns the corresponding output sample.
    ///
    /// With an empty impulse response the input is returned unchanged.
    pub fn process(&mut self, sample: T) -> T {
        if self.ir.is_empty() {
            sample
        } else if self.use_fft {
            self.process_fft(sample)
        } else {
            self.process_direct(sample)
        }
    }

    /// Filters a block of samples.
    ///
    /// This is equivalent to calling [`Convolver::process`] on each sample in
    /// order.
    pub fn process_block(&mut self, input: &[T]) -> Vec<T> {
        input.iter().map(|&s| self.process(s)).collect()
    }

    /// Applies the first `history.len()` taps in the time domain and records
    /// `sample` in the history ring.
    ///
    /// In direct mode this covers the whole IR; in FFT mode it is only the
    /// head.
    fn process_direct(&mut self, sample: T) -> T {
        let len = self.history.len();
        let mut acc = sample * self.ir[0];
        // `history_pos` is the next write slot, so x[n - d] sits `d` slots
        // behind it (wrapping around the ring).
        for (delay, &tap) in self.ir[..len].iter().enumerate().skip(1) {
            let idx = (self.history_pos + len - delay) % len;
            acc = acc + self.history[idx] * tap;
        }
        self.history[self.history_pos] = sample;
        self.history_pos = (self.history_pos + 1) % len;
        acc
    }

    /// Head/tail convolution for long IRs: the direct head plus the tail
    /// contribution prepared by the previous block flush.
    fn process_fft(&mut self, sample: T) -> T {
        let head = self.process_direct(sample);
        let tail = self.output_buffer[self.input_pos];
        self.input_buffer[self.input_pos] = sample;
        self.input_pos += 1;
        if self.input_pos == self.input_buffer.len() {
            self.flush_block();
            self.input_pos = 0;
        }
        head + tail
    }

    /// Convolves the completed input block with the IR tail.
    ///
    /// `output_buffer` is first advanced by one block; the result is then
    /// overlap-added into it.
    fn flush_block(&mut self) {
        let block_size = self.input_buffer.len();
        let fft_size = self.fft_size;
        let zero = Complex::new(T::zero(), T::zero());
        if let Some(ref mut fft) = self.fft {
            // Zero-padded copy of the block into the transform input.
            for (slot, &x) in self.fft_in.iter_mut().zip(&self.input_buffer) {
                *slot = Complex::new(x, T::zero());
            }
            self.fft_in[block_size..].fill(zero);

            fft.fft(&self.fft_in, &mut self.fft_out);
            for (bin, &h) in self.fft_out.iter_mut().zip(&self.ir_fft) {
                *bin = *bin * h;
            }
            // NOTE(review): assumes `ifft` returns the normalised inverse
            // (includes the 1/N factor), as the earlier implementation also
            // assumed. Confirm against `SimpleFFT`.
            fft.ifft(&self.fft_out, &mut self.fft_in);

            // Advance the accumulator by one block, zero the freed end, then
            // add this block's tail output (its start aligns with slot 0).
            self.output_buffer.copy_within(block_size.., 0);
            self.output_buffer[fft_size - block_size..].fill(T::zero());
            for (acc, y) in self.output_buffer.iter_mut().zip(&self.fft_in) {
                *acc = *acc + y.re;
            }
        }
    }

    /// Replaces the impulse response and discards all filter state.
    pub fn set_ir(&mut self, ir: Vec<T>) {
        *self = Self::new(ir);
    }

    /// Returns the current impulse response taps.
    pub fn get_ir(&self) -> &[T] {
        &self.ir
    }

    /// Returns the number of taps in the impulse response.
    pub fn ir_len(&self) -> usize {
        self.ir.len()
    }

    /// Clears all filter state while keeping the impulse response.
    ///
    /// The next sample is then filtered as if it were the first one.
    pub fn reset(&mut self) {
        self.input_pos = 0;
        self.history_pos = 0;
        self.input_buffer.fill(T::zero());
        self.output_buffer.fill(T::zero());
        self.history.fill(T::zero());
    }

    /// Smallest power of two that is `>= n` (1 for `n == 0`).
    fn next_power_of_two(n: usize) -> usize {
        n.next_power_of_two()
    }
}
impl<T: Float + FromPrimitive> Default for Convolver<T> {
    /// Builds a pass-through convolver that has no impulse-response taps.
    fn default() -> Self {
        let empty_ir: Vec<T> = Vec::new();
        Self::new(empty_ir)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An empty impulse response passes input through unchanged.
    #[test]
    fn test_empty_convolver() {
        let mut conv: Convolver<f32> = Convolver::new(Vec::new());
        assert_eq!(conv.process(1.0), 1.0);
    }

    /// A unit impulse response is the identity filter.
    #[test]
    fn test_impulse_response() {
        let ir = vec![1.0];
        let mut conv: Convolver<f32> = Convolver::new(ir);
        assert_eq!(conv.process(1.0), 1.0);
    }

    /// A two-tap moving average turns a step into a 0.5 -> 1.0 ramp.
    #[test]
    fn test_convolver_with_ir() {
        let ir = vec![0.5, 0.5];
        let mut conv: Convolver<f32> = Convolver::new(ir);
        assert_eq!(conv.process(1.0), 0.5);
        assert_eq!(conv.process(1.0), 1.0);
    }

    /// `process_block` yields the same samples as repeated `process` calls.
    #[test]
    fn test_process_block_impulse() {
        let mut conv: Convolver<f32> = Convolver::new(vec![0.5, 0.5]);
        assert_eq!(conv.process_block(&[1.0, 0.0, 0.0]), vec![0.5, 0.5, 0.0]);
    }

    /// The accessors report the impulse response that was supplied.
    #[test]
    fn test_ir_accessors() {
        let conv: Convolver<f32> = Convolver::new(vec![0.25, 0.75]);
        assert_eq!(conv.ir_len(), 2);
        assert_eq!(conv.get_ir(), &[0.25, 0.75]);
        let default_conv: Convolver<f32> = Convolver::default();
        assert_eq!(default_conv.ir_len(), 0);
    }
}