use crate::analysis::fft::fft_in_place;
use crate::buffer::AudioBuffer;
/// Partitioned FFT convolution reverb with a wet/dry mix.
///
/// The impulse response is split into `block_size`-sample partitions. Each partition's
/// spectrum is multiplied with a per-channel frequency-domain delay line of recent input-block
/// spectra, and the inverse transforms are combined by overlap-add. The same (mono) impulse
/// response is applied to every channel.
#[must_use]
#[derive(Debug, Clone)]
pub struct ConvolutionReverb {
    /// FFT length; always `2 * block_size`.
    fft_size: usize,
    /// Processing / IR partition length in frames (a power of two).
    block_size: usize,
    /// `(real, imag)` spectrum of each IR partition, zero-padded to `fft_size`.
    ir_partitions: Vec<(Vec<f64>, Vec<f64>)>,
    /// Per-channel input samples collected for the block currently being filled.
    input_buffers: Vec<Vec<f32>>,
    /// Per-channel overlap-add accumulator of `fft_size` samples; after a block completes its
    /// first `block_size` entries hold the wet output for that block.
    overlap: Vec<Vec<f64>>,
    /// Number of frames already collected in the current block (shared by all channels).
    write_pos: usize,
    /// Wet amount, clamped to `0.0..=1.0` by the constructor and `set_mix`; dry gain is
    /// `1.0 - mix`.
    mix: f32,
    /// Channel count the per-channel state is allocated for; `0` until first allocated.
    channels: usize,
    /// Real-part FFT work buffer of length `fft_size`.
    scratch_real: Vec<f64>,
    /// Imaginary-part FFT work buffer of length `fft_size`.
    scratch_imag: Vec<f64>,
    /// Frequency-domain delay line: per channel, a ring of `(real, imag)` input-block spectra
    /// with one slot per IR partition.
    fdl: Vec<Vec<(Vec<f64>, Vec<f64>)>>,
    /// Ring index of the slot that receives the newest input block's spectrum.
    fdl_pos: usize,
}
impl ConvolutionReverb {
    /// Partition length used by [`ConvolutionReverb::new`].
    const DEFAULT_BLOCK_SIZE: usize = 512;

    /// Creates a reverb for `ir_samples` with the default block size of 512 frames.
    ///
    /// See [`ConvolutionReverb::with_block_size`] for the meaning of the other arguments.
    pub fn new(ir_samples: &[f32], mix: f32, sample_rate: u32) -> Self {
        Self::with_block_size(ir_samples, Self::DEFAULT_BLOCK_SIZE, mix, sample_rate)
    }

    /// Creates a reverb that convolves every channel with the mono impulse response `ir_samples`.
    ///
    /// * `block_size` — partition / processing block length. It is raised to at least 1 and
    ///   rounded up to the next power of two; the FFT size is twice that. For uninterrupted
    ///   output, pass buffers whose frame count is a multiple of
    ///   [`ConvolutionReverb::block_size`] (see [`ConvolutionReverb::process`]).
    /// * `mix` — wet amount, clamped to `0.0..=1.0` (`0.0` = dry only, `1.0` = wet only);
    ///   NaN is treated as `0.0`.
    /// * `sample_rate` — only recorded in the debug log; processing does not depend on it.
    ///
    /// An empty `ir_samples` yields a reverb whose `process` leaves buffers untouched.
    pub fn with_block_size(
        ir_samples: &[f32],
        block_size: usize,
        mix: f32,
        sample_rate: u32,
    ) -> Self {
        let block_size = block_size.max(1).next_power_of_two();
        // Zero-padding each block to twice its length makes the FFT's circular convolution equal
        // to the linear convolution of one input block with one IR partition.
        let fft_size = block_size * 2;
        tracing::debug!(
            ir_len = ir_samples.len(),
            block_size,
            fft_size,
            mix,
            sample_rate,
            "ConvolutionReverb::new"
        );
        let ir_partitions = partition_ir(ir_samples, block_size, fft_size);
        Self {
            fft_size,
            block_size,
            ir_partitions,
            input_buffers: Vec::new(),
            overlap: Vec::new(),
            write_pos: 0,
            mix: Self::sanitize_mix(mix),
            channels: 0,
            scratch_real: vec![0.0; fft_size],
            scratch_imag: vec![0.0; fft_size],
            fdl: Vec::new(),
            fdl_pos: 0,
        }
    }

    /// Sets the wet amount, clamped to `0.0..=1.0`; NaN is treated as `0.0`.
    pub fn set_mix(&mut self, mix: f32) {
        self.mix = Self::sanitize_mix(mix);
    }

    /// Returns the current wet amount in `0.0..=1.0`.
    #[must_use]
    pub fn mix(&self) -> f32 {
        self.mix
    }

    /// Returns the effective processing block size in frames (always a power of two).
    #[must_use]
    pub fn block_size(&self) -> usize {
        self.block_size
    }

    /// Replaces the impulse response (keeping the block size) and clears all processing state.
    ///
    /// If the new IR has a different number of partitions, the per-channel delay line is
    /// reallocated here, so that the next `process` never indexes past the end of it.
    pub fn set_ir(&mut self, ir_samples: &[f32]) {
        tracing::debug!(ir_len = ir_samples.len(), "ConvolutionReverb::set_ir");
        self.ir_partitions = partition_ir(ir_samples, self.block_size, self.fft_size);
        let channels = self.channels;
        self.ensure_channels(channels);
        self.reset();
    }

    /// Clears buffered input, the overlap-add tail and the delay line without reallocating.
    pub fn reset(&mut self) {
        self.input_buffers.iter_mut().for_each(|b| b.fill(0.0));
        self.overlap.iter_mut().for_each(|o| o.fill(0.0));
        for (re, im) in self.fdl.iter_mut().flatten() {
            re.fill(0.0);
            im.fill(0.0);
        }
        self.write_pos = 0;
        self.fdl_pos = 0;
    }

    /// Convolves `buf` in place with the impulse response and applies the wet/dry mix.
    ///
    /// Frames are accumulated into blocks of [`block_size`](Self::block_size) frames. When a
    /// block completes, its output is written back over the frames of that block that lie in
    /// `buf`, so block-aligned buffers are processed with no added latency. Frames of a block
    /// that is still incomplete when `buf` ends are left unmodified. When that block completes
    /// in a later call, only its frames in the later buffer receive output. For uninterrupted
    /// output, pass buffers whose frame count is a multiple of the block size.
    ///
    /// Does nothing if the impulse response is empty.
    pub fn process(&mut self, buf: &mut AudioBuffer) {
        if self.ir_partitions.is_empty() {
            return;
        }
        let ch = buf.channels as usize;
        self.ensure_channels(ch);
        // Never index past the interleaved storage, even if `frames` disagrees with it.
        let frames = buf
            .frames
            .min(buf.samples.len().checked_div(ch).unwrap_or(0));
        let mix = self.mix;
        let dry = 1.0 - mix;
        let block_size = self.block_size;
        let tail_len = self.fft_size - block_size;
        let num_partitions = self.ir_partitions.len();

        let mut frame = 0;
        while frame < frames {
            // Position inside the current block at which this chunk of `buf` starts; non-zero
            // when the block was started by a previous call.
            let block_offset = self.write_pos;
            let to_copy = (block_size - block_offset).min(frames - frame);
            for (c, input) in self.input_buffers.iter_mut().enumerate() {
                let dst = &mut input[block_offset..block_offset + to_copy];
                for (i, slot) in dst.iter_mut().enumerate() {
                    *slot = buf.samples[(frame + i) * ch + c];
                }
            }
            self.write_pos += to_copy;

            if self.write_pos == block_size {
                for c in 0..ch {
                    self.process_block(c);
                }
                self.fdl_pos = (self.fdl_pos + 1) % num_partitions;
                self.write_pos = 0;
                for (c, overlap) in self.overlap.iter_mut().enumerate() {
                    // Block index `block_offset` corresponds to `frame`; earlier indices belong
                    // to a buffer that has already been returned to the caller.
                    for (i, &wet) in overlap[block_offset..block_size].iter().enumerate() {
                        let idx = (frame + i) * ch + c;
                        buf.samples[idx] = buf.samples[idx] * dry + wet as f32 * mix;
                    }
                    // Shift the not-yet-emitted tail to the front for the next block.
                    overlap.copy_within(block_size.., 0);
                    overlap[tail_len..].fill(0.0);
                }
            }
            frame += to_copy;
        }
    }

    /// (Re)allocates per-channel state when the channel count or the IR partition count differs
    /// from what is currently allocated; otherwise does nothing.
    fn ensure_channels(&mut self, ch: usize) {
        let num_partitions = self.ir_partitions.len().max(1);
        let allocated = self.channels == ch
            && self.fdl.len() == ch
            && self.fdl.iter().all(|slots| slots.len() == num_partitions);
        if allocated {
            return;
        }
        let fft_size = self.fft_size;
        self.channels = ch;
        self.input_buffers = vec![vec![0.0; self.block_size]; ch];
        self.overlap = vec![vec![0.0; fft_size]; ch];
        self.fdl = (0..ch)
            .map(|_| {
                (0..num_partitions)
                    .map(|_| (vec![0.0; fft_size], vec![0.0; fft_size]))
                    .collect()
            })
            .collect();
        self.write_pos = 0;
        self.fdl_pos = 0;
    }

    /// Runs one convolution step for `channel` on its completed input block.
    ///
    /// The block is transformed and stored in the delay line. The last N block spectra are
    /// multiplied with the N IR partition spectra and summed, and the inverse transform is
    /// overlap-added into `overlap[channel]`. The scratch buffers double as the spectral
    /// accumulator, so no allocation happens here. If `fft_in_place` reports failure, the block
    /// contributes nothing.
    fn process_block(&mut self, channel: usize) {
        let block_size = self.block_size;
        let fft_size = self.fft_size;
        let num_partitions = self.ir_partitions.len();
        let newest = self.fdl_pos;

        // Zero-pad the block to `fft_size` so the spectral product is a linear convolution.
        for (dst, &src) in self
            .scratch_real
            .iter_mut()
            .zip(&self.input_buffers[channel])
        {
            *dst = f64::from(src);
        }
        self.scratch_real[block_size..].fill(0.0);
        self.scratch_imag.fill(0.0);
        if !fft_in_place(&mut self.scratch_real, &mut self.scratch_imag) {
            return;
        }

        let (slot_re, slot_im) = &mut self.fdl[channel][newest];
        slot_re.copy_from_slice(&self.scratch_real);
        slot_im.copy_from_slice(&self.scratch_imag);

        // Accumulate sum_k X[n - k] * H[k] into the (now free) scratch buffers.
        self.scratch_real.fill(0.0);
        self.scratch_imag.fill(0.0);
        let history = &self.fdl[channel];
        for (k, (ir_re, ir_im)) in self.ir_partitions.iter().enumerate() {
            // Ring slot holding the spectrum of the input block from `k` blocks ago.
            let (x_re, x_im) = &history[(newest + num_partitions - k) % num_partitions];
            let acc = self.scratch_real.iter_mut().zip(self.scratch_imag.iter_mut());
            let input = x_re.iter().zip(x_im.iter());
            let ir = ir_re.iter().zip(ir_im.iter());
            for (((acc_re, acc_im), (&xr, &xi)), (&hr, &hi)) in acc.zip(input).zip(ir) {
                *acc_re += xr * hr - xi * hi;
                *acc_im += xr * hi + xi * hr;
            }
        }

        // Inverse FFT through the forward transform, assuming `fft_in_place` is the
        // unnormalised forward DFT: IFFT(X) = conj(FFT(conj(X))) / N. Only the real part is
        // kept, which the final conjugation does not change.
        for v in self.scratch_imag.iter_mut() {
            *v = -*v;
        }
        if !fft_in_place(&mut self.scratch_real, &mut self.scratch_imag) {
            return;
        }
        let scale = 1.0 / fft_size as f64;
        for (ov, &re) in self.overlap[channel].iter_mut().zip(&self.scratch_real) {
            *ov += re * scale;
        }
    }

    /// Clamps `mix` to `0.0..=1.0`, mapping NaN to `0.0` so it cannot poison every output sample.
    fn sanitize_mix(mix: f32) -> f32 {
        if mix.is_nan() {
            0.0
        } else {
            mix.clamp(0.0, 1.0)
        }
    }
}

#[cfg(test)]
mod regression_tests {
    use super::*;

    #[test]
    fn block_straddling_two_calls_is_written_to_the_right_frames() {
        // One-sample delay IR. The impulse arrives in the second call, inside a block that the
        // first (non-block-aligned) call started.
        let mut reverb = ConvolutionReverb::with_block_size(&[0.0, 1.0], 4, 1.0, 44100);
        let mut first = AudioBuffer::from_interleaved(vec![0.0; 6], 1, 44100).unwrap();
        reverb.process(&mut first);
        assert!(first.samples().iter().all(|s| s.abs() < 1e-6));

        let mut second =
            AudioBuffer::from_interleaved(vec![1.0, 0.0, 0.0, 0.0, 0.0, 0.0], 1, 44100).unwrap();
        reverb.process(&mut second);
        let expected = [0.0f32, 1.0, 0.0, 0.0, 0.0, 0.0];
        for (i, (&got, &want)) in second.samples().iter().zip(&expected).enumerate() {
            assert!((got - want).abs() < 1e-5, "frame {i}: got {got}, expected {want}");
        }
    }

    #[test]
    fn set_ir_with_more_partitions_after_processing() {
        let mut reverb = ConvolutionReverb::with_block_size(&[1.0], 4, 1.0, 44100);
        let mut warmup = AudioBuffer::from_interleaved(vec![0.0; 8], 1, 44100).unwrap();
        reverb.process(&mut warmup);

        // Nine taps give three partitions of four; the only non-zero tap delays by 8 samples.
        let mut ir = vec![0.0f32; 9];
        ir[8] = 1.0;
        reverb.set_ir(&ir);

        let mut samples = vec![0.0; 16];
        samples[0] = 1.0;
        let mut buf = AudioBuffer::from_interleaved(samples, 1, 44100).unwrap();
        reverb.process(&mut buf);
        for (i, &s) in buf.samples().iter().enumerate() {
            let want = if i == 8 { 1.0 } else { 0.0 };
            assert!((s - want).abs() < 1e-5, "frame {i}: got {s}, expected {want}");
        }
    }

    #[test]
    fn nan_mix_is_treated_as_dry() {
        let mut reverb = ConvolutionReverb::with_block_size(&[0.5], 4, f32::NAN, 44100);
        assert_eq!(reverb.mix(), 0.0);
        reverb.set_mix(0.25);
        reverb.set_mix(f32::NAN);
        assert_eq!(reverb.mix(), 0.0);
        let mut buf = AudioBuffer::from_interleaved(vec![1.0; 4], 1, 44100).unwrap();
        reverb.process(&mut buf);
        assert!(buf.samples().iter().all(|s| (s - 1.0).abs() < 1e-6));
    }

    #[test]
    fn block_size_is_rounded_up_to_power_of_two() {
        let reverb = ConvolutionReverb::with_block_size(&[1.0], 5, 1.0, 44100);
        assert_eq!(reverb.block_size(), 8);
        assert_eq!(ConvolutionReverb::new(&[1.0], 1.0, 44100).block_size(), 512);
    }
}
/// Splits `ir` into `block_size`-sample partitions and returns the spectrum of each.
///
/// Each element is `(real, imag)`, both of length `fft_size`: the partition (the last one may
/// be shorter than `block_size`) is zero-padded to `fft_size` and transformed with
/// `fft_in_place`. Returns an empty vector for an empty `ir`. It also returns an empty vector if
/// any transform fails, which callers treat as "no impulse response" (passthrough). Using the
/// untransformed time-domain data as a spectrum would instead give silently wrong output.
///
/// `block_size` must be non-zero (callers pass a power of two).
fn partition_ir(ir: &[f32], block_size: usize, fft_size: usize) -> Vec<(Vec<f64>, Vec<f64>)> {
    let mut partitions = Vec::with_capacity(ir.len().div_ceil(block_size));
    for chunk in ir.chunks(block_size) {
        let mut real = vec![0.0f64; fft_size];
        let mut imag = vec![0.0f64; fft_size];
        for (dst, &src) in real.iter_mut().zip(chunk) {
            *dst = f64::from(src);
        }
        let transformed = fft_in_place(&mut real, &mut imag);
        debug_assert!(transformed, "fft_in_place failed for fft_size {fft_size}");
        if !transformed {
            return Vec::new();
        }
        partitions.push((real, imag));
    }
    partitions
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Absolute tolerance for FFT round-off on signals of roughly unit magnitude.
    const EPS: f32 = 1e-5;

    /// Asserts equal length and sample-by-sample agreement within `eps`.
    fn assert_close(actual: &[f32], expected: &[f32], eps: f32) {
        assert_eq!(actual.len(), expected.len(), "length mismatch");
        for (i, (&got, &want)) in actual.iter().zip(expected).enumerate() {
            assert!((got - want).abs() <= eps, "sample {i}: got {got}, expected {want}");
        }
    }

    /// Reference direct-form convolution of `x` with `h`, truncated to `x.len()` samples.
    fn direct_convolution(x: &[f32], h: &[f32]) -> Vec<f32> {
        (0..x.len())
            .map(|n| {
                h.iter()
                    .take(n + 1)
                    .enumerate()
                    .map(|(k, &hk)| f64::from(hk) * f64::from(x[n - k]))
                    .sum::<f64>() as f32
            })
            .collect()
    }

    #[test]
    fn identity_ir() {
        let ir = vec![1.0];
        let mut reverb = ConvolutionReverb::with_block_size(&ir, 4, 1.0, 44100);
        let samples = vec![1.0, 0.5, 0.25, 0.0, 0.0, 0.0, 0.0, 0.0];
        let mut buf = AudioBuffer::from_interleaved(samples.clone(), 1, 44100).unwrap();
        reverb.process(&mut buf);
        // A unit impulse response with a fully wet mix reproduces the input.
        assert_close(buf.samples(), &samples, EPS);
    }

    #[test]
    fn delay_ir() {
        let ir = vec![0.0, 0.0, 0.0, 1.0];
        let mut reverb = ConvolutionReverb::with_block_size(&ir, 4, 1.0, 44100);
        let mut samples = vec![0.0; 16];
        samples[0] = 1.0;
        let mut buf = AudioBuffer::from_interleaved(samples, 1, 44100).unwrap();
        reverb.process(&mut buf);
        // The impulse must come out exactly three samples later and nowhere else.
        let mut expected = vec![0.0f32; 16];
        expected[3] = 1.0;
        assert_close(buf.samples(), &expected, EPS);
    }

    #[test]
    fn stereo_convolution() {
        let ir = vec![1.0, 0.5, 0.25];
        let mut reverb = ConvolutionReverb::with_block_size(&ir, 4, 1.0, 44100);
        let samples = vec![
            1.0, 0.5, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        let mut buf = AudioBuffer::from_interleaved(samples, 2, 44100).unwrap();
        reverb.process(&mut buf);
        // Left carries an impulse of 1.0 and right one of 0.5, both at frame 0; each channel is
        // convolved independently and the result stays interleaved.
        let expected: [f32; 16] = [
            1.0, 0.5, 0.5, 0.25, 0.25, 0.125, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0,
        ];
        assert_close(buf.samples(), &expected, EPS);
    }

    #[test]
    fn mix_control() {
        let ir = vec![1.0];
        let mut reverb = ConvolutionReverb::with_block_size(&ir, 4, 0.5, 44100);
        assert_eq!(reverb.mix(), 0.5);
        let mut buf = AudioBuffer::from_interleaved(vec![1.0; 8], 1, 44100).unwrap();
        reverb.process(&mut buf);
        // Half dry plus half of an identical wet signal gives back the input.
        assert_close(buf.samples(), &[1.0f32; 8], EPS);
        reverb.set_mix(2.0);
        assert_eq!(reverb.mix(), 1.0);
        reverb.set_mix(-1.0);
        assert_eq!(reverb.mix(), 0.0);
    }

    #[test]
    fn empty_ir_passthrough() {
        let ir: Vec<f32> = vec![];
        let mut reverb = ConvolutionReverb::new(&ir, 1.0, 44100);
        let original = vec![0.5; 1024];
        let mut buf = AudioBuffer::from_interleaved(original.clone(), 1, 44100).unwrap();
        reverb.process(&mut buf);
        assert_eq!(buf.samples(), &original);
    }

    #[test]
    fn long_ir() {
        let ir: Vec<f32> = (0..2048).map(|i| (-0.001 * i as f32).exp()).collect();
        let mut reverb = ConvolutionReverb::with_block_size(&ir, 256, 1.0, 44100);
        let samples: Vec<f32> = (0..4096)
            .map(|i| (2.0 * std::f32::consts::PI * 440.0 * i as f32 / 44100.0).sin() * 0.5)
            .collect();
        let mut buf = AudioBuffer::from_interleaved(samples.clone(), 1, 44100).unwrap();
        reverb.process(&mut buf);
        assert!(buf.samples().iter().all(|s| s.is_finite()));
        assert!(buf.rms() > 0.0);
        // Compare with direct convolution over a span that exercises several IR partitions.
        let check = 1024;
        let expected = direct_convolution(&samples[..check], &ir);
        assert_close(&buf.samples()[..check], &expected, 1e-3);
    }

    #[test]
    fn reset_clears_state() {
        let ir = vec![1.0, 0.5, 0.25];
        let mut with_reset = ConvolutionReverb::with_block_size(&ir, 4, 1.0, 44100);
        let mut without_reset = ConvolutionReverb::with_block_size(&ir, 4, 1.0, 44100);

        let mut buf = AudioBuffer::from_interleaved(vec![1.0; 8], 1, 44100).unwrap();
        with_reset.process(&mut buf);
        let mut buf = AudioBuffer::from_interleaved(vec![1.0; 8], 1, 44100).unwrap();
        without_reset.process(&mut buf);

        with_reset.reset();
        let mut silent = AudioBuffer::from_interleaved(vec![0.0; 8], 1, 44100).unwrap();
        with_reset.process(&mut silent);
        assert!(silent.samples().iter().all(|s| s.abs() < EPS));

        // Without a reset, the convolution tail of the first buffer is still audible.
        let mut tail = AudioBuffer::from_interleaved(vec![0.0; 8], 1, 44100).unwrap();
        without_reset.process(&mut tail);
        assert!((tail.samples()[0] - 0.75).abs() < EPS);
        assert!((tail.samples()[1] - 0.25).abs() < EPS);
    }

    #[test]
    fn set_ir_replaces() {
        let mut reverb = ConvolutionReverb::new(&[1.0], 1.0, 44100);
        reverb.set_ir(&[1.0, 0.5, 0.25, 0.125]);
        let mut buf = AudioBuffer::from_interleaved(vec![1.0; 2048], 1, 44100).unwrap();
        reverb.process(&mut buf);
        // Step response of the new IR: running sums 1, 1.5, 1.75, then 1.875 from then on.
        let expected: Vec<f32> = (0..2048)
            .map(|n| match n {
                0 => 1.0,
                1 => 1.5,
                2 => 1.75,
                _ => 1.875,
            })
            .collect();
        assert_close(buf.samples(), &expected, EPS);
    }
}