use crate::{AudioEffect, EffectError, Result};
use oxifft::Complex;
/// FFT-based convolution reverb using overlap-add.
///
/// Input is gathered into hops of `fft_size / 2` samples. Each full hop is
/// zero-padded to `fft_size`, multiplied with the pre-computed IR spectrum and
/// transformed back. The portion of the result that extends past the hop is
/// carried in `tail_buffer` and added onto the next block.
pub struct ConvolutionReverb {
    // ---- impulse response -------------------------------------------------
    /// FFT length: the next power of two at or above twice the IR length.
    fft_size: usize,
    /// Spectrum of the zero-padded impulse response (`fft_size` bins).
    ir_fft: Vec<Complex<f32>>,
    /// Original IR length in samples. Stored for reference; nothing reads it yet.
    #[allow(dead_code)]
    ir_length: usize,

    // ---- mix --------------------------------------------------------------
    /// Gain applied to the convolved signal, kept in `[0, 1]` by `set_wet`.
    wet: f32,
    /// Gain applied to the unprocessed input, kept in `[0, 1]` by `set_dry`.
    dry: f32,
    /// Sample rate passed at construction. Stored but not used by the DSP yet.
    #[allow(dead_code)]
    sample_rate: f32,

    // ---- streaming state --------------------------------------------------
    /// Time-domain input frame. Only the first `fft_size / 2` slots receive
    /// samples; the rest is zero padding.
    input_buffer: Vec<f32>,
    /// Scratch copy of `input_buffer` as complex values for the forward FFT.
    input_fft: Vec<Complex<f32>>,
    /// Result of the most recent block (convolution plus carried tail).
    output_buffer: Vec<f32>,
    /// Overlap carried from the previous block into the next one.
    tail_buffer: Vec<f32>,
    /// Write position inside the current input hop.
    input_pos: usize,
    /// Read position inside `output_buffer`.
    output_pos: usize,
}
impl ConvolutionReverb {
    /// Creates a convolution reverb from `impulse_response`.
    ///
    /// The IR is zero-padded to `fft_size`, the next power of two at or above
    /// `2 * impulse_response.len()`, and transformed once up front. Audio is
    /// then processed in hops of `fft_size / 2` samples. Wet and dry gains
    /// both start at `0.5`. `sample_rate` is stored but not used by the DSP.
    ///
    /// # Errors
    ///
    /// Returns [`EffectError::InvalidParameter`] if `impulse_response` is empty,
    /// longer than 100 000 samples, or contains a NaN/infinite sample (such a
    /// sample would spread into every spectral bin and make the wet signal NaN).
    pub fn new(impulse_response: &[f32], sample_rate: f32) -> Result<Self> {
        if impulse_response.is_empty() {
            return Err(EffectError::InvalidParameter(
                "Impulse response cannot be empty".into(),
            ));
        }
        if impulse_response.len() > 100_000 {
            return Err(EffectError::InvalidParameter(
                "Impulse response too long (max 100k samples)".into(),
            ));
        }
        if impulse_response.iter().any(|x| !x.is_finite()) {
            return Err(EffectError::InvalidParameter(
                "Impulse response contains non-finite samples".into(),
            ));
        }
        let ir_length = impulse_response.len();
        // A hop of fft_size / 2 >= ir_length samples convolved with the IR is
        // at most fft_size - 1 long, so the circular convolution never wraps.
        let fft_size = (ir_length * 2).next_power_of_two();
        let mut ir_padded: Vec<Complex<f32>> = impulse_response
            .iter()
            .map(|&x| Complex::new(x, 0.0))
            .collect();
        ir_padded.resize(fft_size, Complex::new(0.0, 0.0));

        // Fold the pipeline's normalisation into the IR spectrum once, so that
        // process_block needs no per-block scaling.
        let inv_gain = Self::round_trip_gain(fft_size).recip();
        let ir_fft: Vec<Complex<f32>> = oxifft::fft(&ir_padded)
            .iter()
            .map(|bin| Complex::new(bin.re * inv_gain, bin.im * inv_gain))
            .collect();

        Ok(Self {
            ir_fft,
            ir_length,
            fft_size,
            input_buffer: vec![0.0; fft_size],
            input_fft: vec![Complex::new(0.0, 0.0); fft_size],
            output_buffer: vec![0.0; fft_size],
            tail_buffer: vec![0.0; fft_size],
            input_pos: 0,
            output_pos: 0,
            wet: 0.5,
            dry: 0.5,
            sample_rate,
        })
    }

    /// Sets the gain of the convolved (wet) signal, clamped to `[0, 1]`.
    ///
    /// A NaN argument is ignored and the previous gain is kept, because
    /// `f32::clamp` would otherwise store NaN and make every output sample NaN.
    pub fn set_wet(&mut self, wet: f32) {
        if !wet.is_nan() {
            self.wet = wet.clamp(0.0, 1.0);
        }
    }

    /// Sets the gain of the unprocessed (dry) signal, clamped to `[0, 1]`.
    ///
    /// A NaN argument is ignored and the previous gain is kept.
    pub fn set_dry(&mut self, dry: f32) {
        if !dry.is_nan() {
            self.dry = dry.clamp(0.0, 1.0);
        }
    }

    /// Gain of the `fft -> bin-wise multiply -> ifft` pipeline at `fft_size`,
    /// measured by convolving a unit impulse with itself.
    ///
    /// Dividing the IR spectrum by this value yields unity-gain convolution
    /// whatever normalisation convention `oxifft::ifft` uses. It is `1.0` if
    /// `ifft` divides by `n`, and `fft_size` if it does not.
    fn round_trip_gain(fft_size: usize) -> f32 {
        let mut impulse = vec![Complex::new(0.0_f32, 0.0); fft_size];
        impulse[0] = Complex::new(1.0, 0.0);
        let spectrum = oxifft::fft(&impulse);
        let squared: Vec<Complex<f32>> = spectrum.iter().map(|&bin| bin * bin).collect();
        let gain = oxifft::ifft(&squared)[0].re;
        debug_assert!(
            gain.is_finite() && gain > 0.0,
            "FFT round-trip gain must be positive"
        );
        gain
    }

    /// Convolves the current input frame and performs the overlap-add step.
    ///
    /// On return, `output_buffer[..fft_size / 2]` holds the finished output
    /// for the hop. Its second half has moved into `tail_buffer` for the next
    /// block, and `output_pos` is rewound to 0.
    fn process_block(&mut self) {
        for (slot, &sample) in self.input_fft.iter_mut().zip(&self.input_buffer) {
            *slot = Complex::new(sample, 0.0);
        }
        let input_spectrum = oxifft::fft(&self.input_fft);
        let product: Vec<Complex<f32>> = input_spectrum
            .iter()
            .zip(&self.ir_fft)
            .map(|(&a, &b)| a * b)
            .collect();
        // Gain is already folded into `ir_fft`, so no extra scaling here.
        let convolved = oxifft::ifft(&product);

        for ((out, value), &carry) in self
            .output_buffer
            .iter_mut()
            .zip(&convolved)
            .zip(&self.tail_buffer)
        {
            *out = value.re + carry;
        }

        // The second half becomes the tail for the next hop. The tail's upper
        // half stays zero because each block's convolution fits in fft_size.
        let half = self.fft_size / 2;
        let (tail_lo, tail_hi) = self.tail_buffer.split_at_mut(half);
        tail_lo.copy_from_slice(&self.output_buffer[half..]);
        tail_hi.fill(0.0);

        self.output_pos = 0;
    }
}
impl AudioEffect for ConvolutionReverb {
    const EFFECT_ID: &'static str = "convolution_reverb";

    /// Pushes one input sample and returns `wet * convolved + dry * input`.
    ///
    /// The convolved signal is delayed by `latency_samples()`; the dry path is
    /// not delayed.
    fn process_sample(&mut self, input: f32) -> f32 {
        self.input_buffer[self.input_pos] = input;
        self.input_pos += 1;
        if self.input_pos >= self.fft_size / 2 {
            // Only `input_buffer[..fft_size / 2]` is ever written here, so the
            // upper half stays zero and serves as the FFT zero padding.
            self.process_block();
            self.input_pos = 0;
        }
        let wet_sample = self
            .output_buffer
            .get(self.output_pos)
            .copied()
            .unwrap_or(0.0);
        self.output_pos += 1;
        wet_sample * self.wet + input * self.dry
    }

    /// Clears all streaming state (input, output and tail buffers, positions).
    fn reset(&mut self) {
        self.input_buffer.fill(0.0);
        self.output_buffer.fill(0.0);
        self.tail_buffer.fill(0.0);
        self.input_fft.fill(Complex::new(0.0, 0.0));
        self.input_pos = 0;
        self.output_pos = 0;
    }

    /// Delay of the wet signal in samples: `fft_size / 2 - 1`.
    ///
    /// The sample that completes a hop triggers `process_block`, and
    /// `output_buffer[0]` is read immediately afterwards. That slot holds the
    /// result for the hop's first sample, which arrived `fft_size / 2 - 1`
    /// calls earlier.
    fn latency_samples(&self) -> usize {
        (self.fft_size / 2).saturating_sub(1)
    }
}
/// Zero-latency block convolver using FFT overlap-add with two output buffers.
///
/// Each `process_block` call convolves one input block, writes the result into
/// the inactive buffer and then makes that buffer the active one.
pub struct DoubleBufferConvolver {
    /// Samples per input/output block (a power of two, at least 1).
    block_size: usize,
    /// FFT length: the next power of two at or above `ir_len + block_size`.
    fft_size: usize,
    /// Spectrum of the zero-padded impulse response (`fft_size` bins).
    ir_fft: Vec<Complex<f32>>,
    /// Convolution tail (`fft_size - block_size` samples) carried to later blocks.
    overlap: Vec<f32>,
    /// First of the two `fft_size`-long work buffers.
    buffer_a: Vec<f32>,
    /// Second of the two `fft_size`-long work buffers.
    buffer_b: Vec<f32>,
    /// Which buffer holds the latest output: 0 for `buffer_a`, 1 for `buffer_b`.
    active_buffer: usize,
}
impl DoubleBufferConvolver {
    /// Builds a convolver for `impulse_response` operating on fixed-size blocks.
    ///
    /// `block_size` is raised to at least 1 and rounded up to the next power of
    /// two; use [`Self::block_size`] to learn the size `process_block` expects.
    ///
    /// # Panics
    ///
    /// Panics if `impulse_response` is empty.
    #[must_use]
    pub fn new(impulse_response: &[f32], block_size: usize) -> Self {
        assert!(
            !impulse_response.is_empty(),
            "impulse_response must not be empty"
        );
        let block_size = block_size.max(1).next_power_of_two();
        // Long enough that block (+) IR never wraps in the circular convolution.
        let fft_size = (impulse_response.len() + block_size).next_power_of_two();

        let mut padded: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); fft_size];
        for (slot, &h) in padded.iter_mut().zip(impulse_response) {
            *slot = Complex::new(h, 0.0);
        }

        Self {
            ir_fft: oxifft::fft(&padded),
            fft_size,
            block_size,
            overlap: vec![0.0; fft_size - block_size],
            buffer_a: vec![0.0; fft_size],
            buffer_b: vec![0.0; fft_size],
            active_buffer: 0,
        }
    }

    /// Convolves one block and returns the corresponding `block_size` outputs.
    ///
    /// The tail of each block's convolution is kept in `overlap` and added to
    /// the following blocks, so the output stream equals the linear
    /// convolution of the input stream with the IR, with no added delay.
    ///
    /// # Panics
    ///
    /// Panics if `input.len()` differs from [`Self::block_size`].
    pub fn process_block(&mut self, input: &[f32]) -> &[f32] {
        assert_eq!(
            input.len(),
            self.block_size,
            "input.len() must equal block_size"
        );

        let mut frame: Vec<Complex<f32>> = vec![Complex::new(0.0, 0.0); self.fft_size];
        for (slot, &x) in frame.iter_mut().zip(input) {
            *slot = Complex::new(x, 0.0);
        }
        let spectrum = oxifft::fft(&frame);
        let product: Vec<Complex<f32>> = spectrum
            .iter()
            .zip(&self.ir_fft)
            .map(|(a, b)| Complex::new(a.re * b.re - a.im * b.im, a.re * b.im + a.im * b.re))
            .collect();
        let time_domain = oxifft::ifft(&product);

        let target = 1 - self.active_buffer;
        let block = self.block_size;
        let buf = if target == 0 {
            &mut self.buffer_a
        } else {
            &mut self.buffer_b
        };

        for (dst, src) in buf.iter_mut().zip(&time_domain) {
            *dst = src.re;
        }
        for (dst, &carry) in buf.iter_mut().zip(&self.overlap) {
            *dst += carry;
        }
        // Everything past this block becomes the carry for later blocks.
        self.overlap.copy_from_slice(&buf[block..]);
        buf[block..].fill(0.0);

        self.active_buffer = target;
        let ready = if self.active_buffer == 0 {
            &self.buffer_a
        } else {
            &self.buffer_b
        };
        &ready[..block]
    }

    /// Block length (after rounding) that `process_block` requires.
    #[must_use]
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Discards the carried tail and both work buffers, and makes `buffer_a` active.
    pub fn clear(&mut self) {
        self.overlap.fill(0.0);
        self.buffer_a.fill(0.0);
        self.buffer_b.fill(0.0);
        self.active_buffer = 0;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_convolution_reverb_creation() {
        let ir = vec![1.0, 0.5, 0.25, 0.125];
        let reverb = ConvolutionReverb::new(&ir, 48000.0);
        assert!(reverb.is_ok());
    }

    #[test]
    fn test_convolution_reverb_empty_ir() {
        let ir: Vec<f32> = vec![];
        let result = ConvolutionReverb::new(&ir, 48000.0);
        assert!(result.is_err());
    }

    #[test]
    fn test_convolution_reverb_too_long_ir() {
        // 100_000 samples is the maximum accepted IR length.
        let ir = vec![0.0_f32; 100_001];
        assert!(ConvolutionReverb::new(&ir, 48000.0).is_err());
    }

    #[test]
    fn test_convolution_reverb_process() {
        let ir = vec![1.0; 100];
        let mut reverb = ConvolutionReverb::new(&ir, 48000.0).expect("test expectation failed");
        let output = reverb.process_sample(1.0);
        assert!(output.is_finite());
        for _ in 0..1000 {
            let out = reverb.process_sample(0.0);
            assert!(out.is_finite());
        }
    }

    #[test]
    fn test_convolution_wet_dry() {
        let ir = vec![0.5; 50];
        let mut reverb = ConvolutionReverb::new(&ir, 48000.0).expect("test expectation failed");
        reverb.set_wet(0.0);
        reverb.set_dry(1.0);
        for _ in 0..100 {
            reverb.process_sample(1.0);
        }
        // With wet = 0 and dry = 1 the output is exactly the input sample.
        let output = reverb.process_sample(1.0);
        assert!(
            (output - 1.0).abs() < 1e-6,
            "expected dry-only output, got {output}"
        );
    }

    #[test]
    fn test_convolution_reverb_reset_clears_state() {
        let ir = vec![0.5_f32; 32];
        let mut reverb = ConvolutionReverb::new(&ir, 48000.0).expect("test expectation failed");
        for _ in 0..200 {
            reverb.process_sample(1.0);
        }
        reverb.reset();
        // After reset no tail may leak out: silence in must give silence out.
        for i in 0..200 {
            let out = reverb.process_sample(0.0);
            assert!(out.abs() < 1e-12, "sample {i} after reset is {out}");
        }
    }

    #[test]
    fn test_double_buffer_convolver_delta() {
        const BLOCK: usize = 64;
        let mut ir = vec![0.0_f32; BLOCK];
        ir[0] = 1.0;
        let mut conv = DoubleBufferConvolver::new(&ir, BLOCK);
        let mut input_block = vec![0.0_f32; BLOCK];
        input_block[0] = 1.0;
        let out0 = conv.process_block(&input_block);
        assert_eq!(out0.len(), BLOCK, "output block length must equal BLOCK");
        let peak = out0.iter().cloned().fold(0.0_f32, f32::max);
        assert!(
            peak > 0.5,
            "delta IR convolution peak should be > 0.5 (got {peak:.4})"
        );
        for (i, &s) in out0.iter().enumerate() {
            assert!(s.is_finite(), "out0[{i}] is not finite: {s}");
        }
        let silence = vec![0.0_f32; BLOCK];
        let out1 = conv.process_block(&silence);
        assert_eq!(out1.len(), BLOCK, "second block length must equal BLOCK");
        for (i, &s) in out1.iter().enumerate() {
            assert!(s.is_finite(), "out1[{i}] is not finite: {s}");
        }
    }

    #[test]
    fn test_double_buffer_convolver_matches_direct() {
        use std::f32::consts::TAU;
        const N_BLOCKS: usize = 8;
        const BLOCK: usize = 32;
        const IR_LEN: usize = 16;
        const TOTAL: usize = N_BLOCKS * BLOCK;
        let ir: Vec<f32> = (0..IR_LEN)
            .map(|i| {
                let x = (i as f32 - IR_LEN as f32 / 2.0) / (IR_LEN as f32 / 4.0);
                (-x * x / 2.0).exp() / IR_LEN as f32
            })
            .collect();
        let input: Vec<f32> = (0..TOTAL)
            .map(|i| (TAU * 440.0 * i as f32 / 48_000.0).sin() * 0.5)
            .collect();
        let mut direct = vec![0.0_f32; TOTAL + IR_LEN];
        for (n, &x) in input.iter().enumerate() {
            for (k, &h) in ir.iter().enumerate() {
                direct[n + k] += x * h;
            }
        }
        let mut conv = DoubleBufferConvolver::new(&ir, BLOCK);
        let mut ola_out = Vec::with_capacity(TOTAL);
        for block_idx in 0..N_BLOCKS {
            let start = block_idx * BLOCK;
            let chunk = &input[start..start + BLOCK];
            let out = conv.process_block(chunk);
            ola_out.extend_from_slice(out);
        }
        let skip = IR_LEN;
        for (i, (&ola, &dir)) in ola_out[skip..]
            .iter()
            .zip(direct[skip..].iter())
            .enumerate()
            .take(TOTAL - skip)
        {
            let err = (ola - dir).abs();
            assert!(
                err < 1e-3,
                "OLA vs direct mismatch at sample {}: ola={:.6}, dir={:.6}, err={:.6}",
                skip + i,
                ola,
                dir,
                err
            );
        }
    }

    #[test]
    fn test_double_buffer_convolver_rounds_block_size() {
        assert_eq!(DoubleBufferConvolver::new(&[1.0], 100).block_size(), 128);
        assert_eq!(DoubleBufferConvolver::new(&[1.0], 64).block_size(), 64);
        assert_eq!(DoubleBufferConvolver::new(&[1.0], 0).block_size(), 1);
    }

    #[test]
    #[should_panic(expected = "input.len() must equal block_size")]
    fn test_double_buffer_convolver_rejects_wrong_block_len() {
        let mut conv = DoubleBufferConvolver::new(&[1.0], 4);
        conv.process_block(&[0.0; 3]);
    }

    #[test]
    fn test_double_buffer_convolver_clear_drops_tail() {
        const BLOCK: usize = 16;
        let ir = vec![1.0_f32; 8];
        let silence = vec![0.0_f32; BLOCK];
        let mut impulse = vec![0.0_f32; BLOCK];
        // An impulse at the last slot makes the IR spill into the next block.
        impulse[BLOCK - 1] = 1.0;

        let mut conv = DoubleBufferConvolver::new(&ir, BLOCK);
        conv.process_block(&impulse);
        let tail_energy: f32 = conv.process_block(&silence).iter().map(|s| s * s).sum();
        assert!(tail_energy > 0.5, "expected a carried tail, energy {tail_energy}");

        conv.process_block(&impulse);
        conv.clear();
        let out = conv.process_block(&silence);
        assert!(
            out.iter().all(|s| s.abs() < 1e-6),
            "tail leaked through clear(): {out:?}"
        );
    }
}