use crate::error::{Result, TunesError};
use crate::synthesis::sample::Sample;
use crate::track::PRIORITY_SPATIAL;
use rustfft::num_complex::Complex;
use rustfft::{Fft, FftPlanner};
use std::collections::VecDeque;
use std::sync::Arc;
#[cfg(feature = "gpu")]
use std::sync::Mutex;
#[cfg(feature = "gpu")]
use crate::gpu::{GpuConvolution, GpuDevice};
/// FFT-based convolution reverb using block overlap-add.
///
/// Input samples are collected into blocks of `block_size` samples. Each full block is
/// zero-padded to `fft_size`, convolved with the impulse response in the frequency domain, and
/// the part of the result that extends past the block (the convolution tail) is carried over and
/// added to later blocks. On the CPU path the wet signal lags the dry signal by
/// `block_size - 1` samples.
#[derive(Clone)]
pub struct ConvolutionReverb {
    /// Spectrum of the zero-padded impulse response (length `fft_size`).
    ir_fft: Vec<Complex<f32>>,
    /// FFT length: `ir.len() + block_size` rounded up to a power of two, which is long enough
    /// that a block's linear convolution never wraps around.
    fft_size: usize,
    /// Number of input samples convolved per FFT.
    block_size: usize,
    /// Number of input samples consumed after each block. Plain overlap-add applies no analysis
    /// window, so it must advance by a full block; this always equals `block_size`.
    hop_size: usize,
    /// Input samples waiting to be convolved.
    input_buffer: Vec<f32>,
    /// Wet samples ready to be emitted, one per call to [`ConvolutionReverb::process`].
    output_buffer: VecDeque<f32>,
    /// Convolution tail still owed to future output; index 0 lines up with the first sample of
    /// the next block.
    overlap_buffer: Vec<f32>,
    /// Scratch space for the forward/inverse FFT, reused across blocks to avoid allocation.
    fft_working_buffer: Vec<Complex<f32>>,
    fft: Arc<dyn Fft<f32>>,
    ifft: Arc<dyn Fft<f32>>,
    /// GPU convolution backend, installed by `enable_gpu`.
    #[cfg(feature = "gpu")]
    gpu_convolution: Option<Arc<Mutex<GpuConvolution>>>,
    /// Wet/dry balance: `0.0` is fully dry, `1.0` fully wet. Clamped to `0.0..=1.0` by the
    /// constructors and [`ConvolutionReverb::set_mix`] (direct field writes are not clamped).
    pub mix: f32,
    /// Processing priority; initialized to `PRIORITY_SPATIAL`.
    pub priority: u8,
    /// Samples processed since construction or the last [`ConvolutionReverb::reset`].
    sample_count: u64,
    /// Whether `process_block` should try the GPU backend first.
    gpu_enabled: bool,
}

impl ConvolutionReverb {
    /// Creates a reverb from an impulse response given as mono samples.
    ///
    /// `mix` is clamped to `0.0..=1.0`. When `block_size` is `None`, a block size is chosen from
    /// the IR length: 2048 below 4096 samples, 4096 below 16384 samples, otherwise 8192. Larger
    /// blocks need fewer FFTs per sample but add latency.
    ///
    /// # Errors
    ///
    /// Returns [`TunesError::AudioEngineError`] if `ir` is empty or `block_size` is `Some(0)`.
    pub fn from_samples(ir: &[f32], mix: f32, block_size: Option<usize>) -> Result<Self> {
        if ir.is_empty() {
            return Err(TunesError::AudioEngineError(
                "Impulse response cannot be empty".to_string(),
            ));
        }

        let block_size = block_size.unwrap_or_else(|| {
            if ir.len() < 4096 {
                2048
            } else if ir.len() < 16384 {
                4096
            } else {
                8192
            }
        });
        if block_size == 0 {
            // A zero-length block would never drain the input buffer and would run an FFT on
            // every single sample.
            return Err(TunesError::AudioEngineError(
                "Block size must be greater than zero".to_string(),
            ));
        }

        let fft_size = (ir.len() + block_size).next_power_of_two();
        let mut planner = FftPlanner::new();
        let fft = planner.plan_fft_forward(fft_size);
        let ifft = planner.plan_fft_inverse(fft_size);

        // Zero-pad the IR to the FFT length and transform it once up front.
        let mut ir_fft = vec![Complex::new(0.0, 0.0); fft_size];
        for (bin, &sample) in ir_fft.iter_mut().zip(ir) {
            *bin = Complex::new(sample, 0.0);
        }
        fft.process(&mut ir_fft);

        Ok(Self {
            ir_fft,
            fft_size,
            block_size,
            // Overlap-add without windowing advances one full block per FFT.
            hop_size: block_size,
            input_buffer: Vec::with_capacity(block_size),
            output_buffer: VecDeque::with_capacity(block_size),
            overlap_buffer: vec![0.0; fft_size],
            fft_working_buffer: vec![Complex::new(0.0, 0.0); fft_size],
            fft,
            ifft,
            #[cfg(feature = "gpu")]
            gpu_convolution: None,
            mix: mix.clamp(0.0, 1.0),
            priority: PRIORITY_SPATIAL,
            sample_count: 0,
            gpu_enabled: false,
        })
    }

    /// Creates a reverb from an impulse response loaded from an audio file.
    ///
    /// Multi-channel files are downmixed to mono by averaging each interleaved frame. A trailing
    /// partial frame is averaged over the samples it actually has.
    ///
    /// # Errors
    ///
    /// Propagates errors from `Sample::from_file` and [`ConvolutionReverb::from_samples`].
    pub fn from_ir(ir_path: &str, mix: f32, block_size: Option<usize>) -> Result<Self> {
        let ir_sample = Sample::from_file(ir_path)?;
        let channels = ir_sample.channels as usize;
        let ir_mono: Vec<f32> = if channels > 1 {
            ir_sample
                .data
                .chunks(channels)
                // Divide by the chunk's real length: the last chunk may be short.
                .map(|frame| frame.iter().sum::<f32>() / frame.len() as f32)
                .collect()
        } else {
            ir_sample.data.as_ref().clone()
        };
        Self::from_samples(&ir_mono, mix, block_size)
    }

    /// Processes one input sample and returns the dry/wet mix.
    ///
    /// The wet component is silent until the first block has been convolved.
    pub fn process(&mut self, input: f32) -> f32 {
        self.input_buffer.push(input);
        if self.input_buffer.len() >= self.block_size {
            self.process_block();
        }
        let wet = self.output_buffer.pop_front().unwrap_or(0.0);
        self.sample_count += 1;
        input * (1.0 - self.mix) + wet * self.mix
    }

    /// Convolves the oldest `block_size` pending input samples, queues the resulting wet samples
    /// and consumes the block from the input buffer.
    fn process_block(&mut self) {
        #[cfg(feature = "gpu")]
        if self.gpu_enabled {
            if let Some(ref gpu_conv) = self.gpu_convolution {
                let input_block: Vec<f32> = self
                    .input_buffer
                    .iter()
                    .take(self.block_size)
                    .copied()
                    .collect();
                match gpu_conv
                    .lock()
                    .unwrap()
                    .process_block(&input_block, &self.overlap_buffer)
                {
                    Ok((output_samples, new_overlap)) => {
                        self.output_buffer.extend(output_samples);
                        self.overlap_buffer = new_overlap;
                        // NOTE(review): assumes the GPU backend handles exactly one full block
                        // per call, like the CPU path below. Confirm against
                        // `GpuConvolution::process_block`.
                        self.input_buffer.drain(..self.hop_size);
                        return;
                    }
                    Err(e) => {
                        eprintln!("GPU convolution error (falling back to CPU): {}", e);
                        self.gpu_enabled = false;
                    }
                }
            }
        }

        let block = self.block_size;
        // rustfft's inverse transform is unnormalized.
        let scale = 1.0 / self.fft_size as f32;

        // Zero-pad the oldest block of input to the FFT length.
        self.fft_working_buffer.fill(Complex::new(0.0, 0.0));
        for (bin, &sample) in self
            .fft_working_buffer
            .iter_mut()
            .zip(self.input_buffer.iter().take(block))
        {
            *bin = Complex::new(sample, 0.0);
        }

        // Linear convolution by spectral multiplication. fft_size >= ir_len + block, so nothing
        // wraps around.
        self.fft.process(&mut self.fft_working_buffer);
        for (bin, &response) in self.fft_working_buffer.iter_mut().zip(&self.ir_fft) {
            *bin *= response;
        }
        self.ifft.process(&mut self.fft_working_buffer);

        // The head of this block's result plus the tail owed by earlier blocks is final output.
        self.output_buffer.extend(
            self.fft_working_buffer[..block]
                .iter()
                .zip(&self.overlap_buffer[..block])
                .map(|(bin, &carried)| bin.re * scale + carried),
        );

        // Slide the carried tail forward by one block, then accumulate this block's own tail
        // (convolution samples `block..fft_size`) on top of it.
        let tail_len = self.fft_size - block;
        self.overlap_buffer.copy_within(block.., 0);
        self.overlap_buffer[tail_len..].fill(0.0);
        for (carried, bin) in self.overlap_buffer[..tail_len]
            .iter_mut()
            .zip(&self.fft_working_buffer[block..])
        {
            *carried += bin.re * scale;
        }

        self.input_buffer.drain(..self.hop_size);
    }

    /// Runs [`ConvolutionReverb::process`] over `buffer` in place.
    pub fn process_block_direct(&mut self, buffer: &mut [f32]) {
        for sample in buffer.iter_mut() {
            *sample = self.process(*sample);
        }
    }

    /// Clears pending input, queued wet output and the carried convolution tail, and resets the
    /// sample counter. The impulse response, mix and GPU state are kept.
    pub fn reset(&mut self) {
        self.input_buffer.clear();
        self.output_buffer.clear();
        self.overlap_buffer.fill(0.0);
        self.sample_count = 0;
    }

    /// Returns the current wet/dry mix.
    pub fn mix(&self) -> f32 {
        self.mix
    }

    /// Sets the wet/dry mix, clamped to `0.0..=1.0`.
    pub fn set_mix(&mut self, mix: f32) {
        self.mix = mix.clamp(0.0, 1.0);
    }

    /// Switches block processing to the GPU backend.
    ///
    /// If a later GPU call fails, processing falls back to the CPU path and GPU use is disabled.
    ///
    /// # Errors
    ///
    /// Returns [`TunesError::AudioEngineError`] if the GPU device or the GPU convolution cannot
    /// be created.
    #[cfg(feature = "gpu")]
    pub fn enable_gpu(&mut self) -> Result<()> {
        let gpu_device = GpuDevice::new().map_err(|e| {
            TunesError::AudioEngineError(format!("Failed to initialize GPU: {}", e))
        })?;
        let gpu_conv =
            GpuConvolution::new(gpu_device, &self.ir_fft, self.fft_size, self.block_size).map_err(
                |e| {
                    TunesError::AudioEngineError(format!("Failed to create GPU convolution: {}", e))
                },
            )?;
        self.gpu_convolution = Some(Arc::new(Mutex::new(gpu_conv)));
        self.gpu_enabled = true;
        Ok(())
    }

    /// Returns whether GPU processing is currently active (always `false` without the `gpu`
    /// feature).
    pub fn is_gpu_enabled(&self) -> bool {
        self.gpu_enabled
    }
}

#[cfg(test)]
mod overlap_add_tests {
    use super::*;

    #[test]
    fn wet_output_is_linear_convolution_delayed_by_block() {
        let ir = [1.0, 0.5, 0.25];
        let block_size = 4;
        let mut reverb = ConvolutionReverb::from_samples(&ir, 1.0, Some(block_size)).unwrap();
        let mut input = vec![0.0f32; 16];
        input[3] = 1.0;
        let output: Vec<f32> = input.iter().map(|&x| reverb.process(x)).collect();

        // The impulse at index 3 plus `block_size - 1` samples of latency lands at index 6, and
        // its tail crosses into the next block.
        let mut expected = vec![0.0f32; 16];
        expected[6] = 1.0;
        expected[7] = 0.5;
        expected[8] = 0.25;
        for (i, (got, want)) in output.iter().zip(&expected).enumerate() {
            assert!((got - want).abs() < 1e-5, "sample {i}: got {got}, want {want}");
        }
    }

    #[test]
    fn output_queue_stays_bounded() {
        let mut reverb =
            ConvolutionReverb::from_samples(&[1.0, 0.5, 0.25], 0.5, Some(64)).unwrap();
        for _ in 0..10_000 {
            reverb.process(1.0);
            assert!(reverb.output_buffer.len() < 64);
        }
    }

    #[test]
    fn zero_block_size_is_rejected() {
        assert!(ConvolutionReverb::from_samples(&[1.0], 0.5, Some(0)).is_err());
    }
}
impl std::fmt::Debug for ConvolutionReverb {
    /// Formats only the sizing parameters and the mix; FFT plans and buffers are omitted
    /// because they are large and not `Debug`-friendly.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = formatter.debug_struct("ConvolutionReverb");
        view.field("fft_size", &self.fft_size);
        view.field("block_size", &self.block_size);
        view.field("mix", &self.mix);
        view.finish()
    }
}
/// Parameters for synthesizing an impulse response with `generate_ir`.
///
/// The associated preset constructors provide starting points; all of them use a 44.1 kHz
/// sample rate.
#[derive(Debug, Clone, PartialEq)]
pub struct IRParams {
    /// Room size as `(length, width, height)`. Early-reflection delays divide these by a speed
    /// of sound of 343, so presumably metres.
    pub room_dimensions: (f32, f32, f32),
    /// Reverberation time in seconds: the diffuse tail decays by 60 dB over this span, and the
    /// generated IR is `1.5 * rt60` seconds long.
    pub rt60: f32,
    /// One-pole low-pass coefficient for the diffuse tail's noise; larger values smooth more
    /// (a darker tail). Expected in `0.0..=1.0` — not validated.
    pub damping: f32,
    /// Early-reflection density; above `0.5`, `early_density * 20` random reflections are added.
    pub early_density: f32,
    /// Sample rate in Hz used to convert times to sample indices.
    pub sample_rate: f32,
}

impl IRParams {
    /// Small, fairly dry room (0.3 s RT60).
    pub fn small_room() -> Self {
        Self {
            room_dimensions: (4.0, 5.0, 2.5),
            rt60: 0.3,
            damping: 0.6,
            early_density: 0.7,
            sample_rate: 44100.0,
        }
    }

    /// Large concert hall (2.5 s RT60).
    pub fn concert_hall() -> Self {
        Self {
            room_dimensions: (40.0, 30.0, 15.0),
            rt60: 2.5,
            damping: 0.4,
            early_density: 0.9,
            sample_rate: 44100.0,
        }
    }

    /// Very large, long-decay space (4.5 s RT60).
    pub fn cathedral() -> Self {
        Self {
            room_dimensions: (60.0, 40.0, 25.0),
            rt60: 4.5,
            damping: 0.5,
            early_density: 0.8,
            sample_rate: 44100.0,
        }
    }

    /// Plate-reverb approximation: a thin "room" with bright, dense reflections.
    pub fn plate() -> Self {
        Self {
            room_dimensions: (2.0, 1.5, 0.01),
            rt60: 2.0,
            damping: 0.2,
            early_density: 1.0,
            sample_rate: 44100.0,
        }
    }

    /// Spring-reverb approximation: tiny dimensions with a 1 s decay.
    pub fn spring() -> Self {
        Self {
            room_dimensions: (0.5, 0.1, 0.1),
            rt60: 1.0,
            damping: 0.3,
            early_density: 0.6,
            sample_rate: 44100.0,
        }
    }
}
impl ConvolutionReverb {
    /// Builds a reverb from a synthetic impulse response produced by `generate_ir`, using the
    /// automatic block size of [`ConvolutionReverb::from_samples`].
    ///
    /// # Errors
    ///
    /// Propagates any error from [`ConvolutionReverb::from_samples`].
    pub fn from_params(params: IRParams, mix: f32) -> Result<Self> {
        let ir = generate_ir(&params);
        Self::from_samples(&ir, mix, None)
    }
}
/// Synthesizes a mono impulse response from `params`.
///
/// The response is `1.5 * rt60` seconds long (never fewer than one sample). It starts with a
/// unit direct-path impulse, then gets early reflections and a decaying noise tail added. Finally
/// it is peak-normalized so the largest absolute sample is `1.0`.
pub fn generate_ir(params: &IRParams) -> Vec<f32> {
    // 1.5 × RT60 lets the -60 dB/RT60 envelope fall to about -90 dB by the end.
    let duration = params.rt60 * 1.5;
    // `as usize` saturates negative/NaN products to 0. Keep at least the direct-path sample so
    // `ir[0]` below cannot panic for degenerate parameters.
    let num_samples = ((duration * params.sample_rate) as usize).max(1);
    let mut ir = vec![0.0; num_samples];
    ir[0] = 1.0;

    add_early_reflections(&mut ir, params);
    add_diffuse_tail(&mut ir, params);

    let peak = ir.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
    if peak > 0.0 {
        ir.iter_mut().for_each(|sample| *sample /= peak);
    }
    ir
}
/// Adds discrete early reflections to `ir`.
///
/// Six fixed taps are placed at delays derived from the room dimensions and a 343 m/s speed of
/// sound. When `early_density` exceeds `0.5`, `early_density * 20` extra taps with random delays
/// (10–100 ms) and random gains (0.1–0.5) are added. Taps that fall past the end of `ir` are
/// dropped.
fn add_early_reflections(ir: &mut [f32], params: &IRParams) {
    use rand::Rng;

    const SPEED_OF_SOUND: f32 = 343.0;
    let (length, width, height) = params.room_dimensions;
    let sample_rate = params.sample_rate;

    // (path length, gain) for each fixed reflection.
    let fixed_taps = [
        (length, 0.8),
        (width, 0.8),
        (height, 0.7),
        (length * 1.5, 0.6),
        (width * 1.5, 0.6),
        (height * 2.0, 0.5),
    ];
    for (distance, gain) in fixed_taps {
        let tap_index = (distance / SPEED_OF_SOUND * sample_rate) as usize;
        if let Some(slot) = ir.get_mut(tap_index) {
            *slot += gain;
        }
    }

    if params.early_density > 0.5 {
        let mut rng = rand::rng();
        let extra_taps = (params.early_density * 20.0) as usize;
        for _ in 0..extra_taps {
            // Draw order (delay, then gain) matches the original generator.
            let delay_seconds: f32 = rng.random_range(0.01..0.1);
            let tap_index = (delay_seconds * sample_rate) as usize;
            let gain: f32 = rng.random_range(0.1..0.5);
            if let Some(slot) = ir.get_mut(tap_index) {
                *slot += gain;
            }
        }
    }
}
/// Adds an exponentially decaying, low-pass-filtered noise tail to `ir`, starting 50 ms in.
///
/// The envelope falls by 60 dB over `params.rt60` seconds. `params.damping` is the one-pole
/// smoothing coefficient applied to the noise. Nothing is added when `rt60` or `sample_rate` is
/// not positive, since no meaningful decay exists.
fn add_diffuse_tail(ir: &mut [f32], params: &IRParams) {
    use rand::Rng;

    // `>` is false for NaN, so NaN parameters are skipped as well.
    let has_decay = params.rt60 > 0.0 && params.sample_rate > 0.0;
    if !has_decay {
        return;
    }

    let mut rng = rand::rng();
    let start_sample = (0.05 * params.sample_rate) as usize;
    // Envelope slope in dB per sample: -60 dB over `rt60` seconds.
    let decay_db_per_sample = (-60.0 / params.rt60) / params.sample_rate;
    let smoothing = params.damping;
    let input_gain = 1.0 - params.damping;
    let mut lowpass_state = 0.0f32;

    for (offset, sample) in ir.iter_mut().enumerate().skip(start_sample) {
        let noise: f32 = rng.random_range(-1.0..1.0);
        // Evaluate the envelope directly in the dB domain. A per-sample gain coefficient sits so
        // close to 1.0 that f32 rounding of it, raised to large powers, skews the effective RT60.
        let envelope = 10.0f32.powf(decay_db_per_sample * offset as f32 / 20.0);
        lowpass_state = lowpass_state * smoothing + noise * input_gain;
        *sample += lowpass_state * envelope * 0.3;
    }
}
/// Convenience entry points that build a [`ConvolutionReverb`] with automatic block sizing.
pub struct Convolution;

impl Convolution {
    /// Loads an impulse response from `ir_path`; see [`ConvolutionReverb::from_ir`].
    pub fn from_file(ir_path: &str, mix: f32) -> Result<ConvolutionReverb> {
        // `None` lets the reverb choose a block size from the IR length.
        let auto_block_size: Option<usize> = None;
        ConvolutionReverb::from_ir(ir_path, mix, auto_block_size)
    }

    /// Synthesizes an impulse response from `params`; see [`ConvolutionReverb::from_params`].
    pub fn from_params(params: IRParams, mix: f32) -> Result<ConvolutionReverb> {
        let built = ConvolutionReverb::from_params(params, mix);
        built
    }
}
pub mod presets {
use super::*;
pub fn small_room(mix: f32) -> Result<ConvolutionReverb> {
ConvolutionReverb::from_params(IRParams::small_room(), mix)
}
pub fn concert_hall(mix: f32) -> Result<ConvolutionReverb> {
ConvolutionReverb::from_params(IRParams::concert_hall(), mix)
}
pub fn cathedral(mix: f32) -> Result<ConvolutionReverb> {
ConvolutionReverb::from_params(IRParams::cathedral(), mix)
}
pub fn plate(mix: f32) -> Result<ConvolutionReverb> {
ConvolutionReverb::from_params(IRParams::plate(), mix)
}
pub fn spring(mix: f32) -> Result<ConvolutionReverb> {
ConvolutionReverb::from_params(IRParams::spring(), mix)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// All built-in parameter presets.
    fn all_presets() -> [IRParams; 5] {
        [
            IRParams::small_room(),
            IRParams::concert_hall(),
            IRParams::cathedral(),
            IRParams::plate(),
            IRParams::spring(),
        ]
    }

    #[test]
    fn test_create_from_samples() {
        let ir = [1.0, 0.5, 0.25, 0.1];
        let reverb = ConvolutionReverb::from_samples(&ir, 0.5, None);
        assert!(reverb.is_ok());
    }

    #[test]
    fn test_process_sample() {
        let ir = [1.0, 0.5, 0.25];
        let mut reverb = ConvolutionReverb::from_samples(&ir, 0.5, Some(256)).unwrap();
        for _ in 0..1000 {
            let output = reverb.process(1.0);
            assert!(output.is_finite());
        }
    }

    #[test]
    fn test_empty_ir_fails() {
        let ir: Vec<f32> = vec![];
        let result = ConvolutionReverb::from_samples(&ir, 0.5, None);
        assert!(result.is_err());
    }

    #[test]
    fn test_reset() {
        let ir = [1.0, 0.5, 0.25];
        let mut reverb = ConvolutionReverb::from_samples(&ir, 0.5, None).unwrap();
        for _ in 0..100 {
            reverb.process(1.0);
        }
        reverb.reset();
        assert_eq!(reverb.sample_count, 0);
    }

    #[test]
    fn test_generate_ir() {
        let params = IRParams::small_room();
        let ir = generate_ir(&params);
        assert!(!ir.is_empty());
        // The direct-path impulse must survive normalization.
        assert!(ir[0].abs() > 0.0);
    }

    #[test]
    fn test_from_params() {
        let reverb = ConvolutionReverb::from_params(IRParams::cathedral(), 0.5);
        assert!(reverb.is_ok());
    }

    #[test]
    fn test_ir_presets() {
        for params in all_presets() {
            let ir = generate_ir(&params);
            assert!(!ir.is_empty());
            assert!(ir.iter().all(|x| x.is_finite()));
        }
    }

    #[test]
    fn test_generated_ir_is_peak_normalized() {
        for params in all_presets() {
            let ir = generate_ir(&params);
            let peak = ir.iter().fold(0.0f32, |acc, x| acc.max(x.abs()));
            assert!((peak - 1.0).abs() < 1e-6, "peak was {peak}");
        }
    }
}