use std::any::Any;
use fft_convolver::FFTConvolver;
use crate::buffer::AudioBuffer;
use crate::context::{AudioContextRegistration, BaseAudioContext};
use crate::render::{
AudioParamValues, AudioProcessor, AudioRenderQuantum, AudioWorkletGlobalScope,
};
use crate::RENDER_QUANTUM_SIZE;
use super::{AudioNode, AudioNodeOptions, ChannelConfig, ChannelCountMode, ChannelInterpretation};
/// Compute the equal-power normalization scale factor for an impulse
/// response buffer, as described by the Web Audio API specification for
/// `ConvolverNode.normalize`.
///
/// Returns the gain to apply to every sample of the impulse response so that
/// convolution output has a predictable loudness regardless of the response's
/// own energy.
fn normalize_buffer(buffer: &AudioBuffer) -> f32 {
    // Constants from the Web Audio API normalization algorithm.
    const GAIN_CALIBRATION: f32 = 0.00125;
    const GAIN_CALIBRATION_SAMPLE_RATE: f32 = 44100.;
    const MIN_POWER: f32 = 0.000125;

    let number_of_channels = buffer.number_of_channels();
    let length = buffer.length();
    let sample_rate = buffer.sample_rate();

    // Sum of squared samples across all channels.
    let squared_sum: f32 = buffer
        .channels()
        .iter()
        .map(|c| c.as_slice().iter().map(|&s| s * s).sum::<f32>())
        .sum();

    // RMS power over the whole buffer. An empty buffer divides by zero and
    // yields NaN/inf, which is clamped below.
    let mut power = (squared_sum / (number_of_channels * length) as f32).sqrt();

    // `is_finite()` is false for NaN as well as infinities, so a separate
    // NaN check is redundant. Clamp degenerate/quiet responses to a floor.
    if !power.is_finite() || power < MIN_POWER {
        power = MIN_POWER;
    }

    let mut scale = 1. / power;
    scale *= GAIN_CALIBRATION;
    // Scale is inversely proportional to the buffer's sample rate.
    scale *= GAIN_CALIBRATION_SAMPLE_RATE / sample_rate;
    // True-stereo (4-channel) responses are mixed down, so halve the scale.
    if number_of_channels == 4 {
        scale *= 0.5;
    }
    scale
}
/// Options for constructing a `ConvolverNode`.
#[derive(Clone, Debug)]
pub struct ConvolverOptions {
    /// The desired impulse response buffer, if any
    pub buffer: Option<AudioBuffer>,
    /// The opposite of the desired initial value for normalization
    /// (`false` means normalization is enabled)
    pub disable_normalization: bool,
    /// Generic audio node options (channel count / mode / interpretation)
    pub audio_node_options: AudioNodeOptions,
}
impl Default for ConvolverOptions {
    /// Default options: no impulse response, normalization enabled, and a
    /// stereo node with `ClampedMax` channel count mode and `Speakers`
    /// interpretation.
    fn default() -> Self {
        let audio_node_options = AudioNodeOptions {
            channel_count: 2,
            channel_count_mode: ChannelCountMode::ClampedMax,
            channel_interpretation: ChannelInterpretation::Speakers,
        };

        Self {
            buffer: None,
            disable_normalization: false,
            audio_node_options,
        }
    }
}
/// Panics with a `NotSupportedError` message when the requested channel
/// count is not supported by `ConvolverNode` (only 1 or 2 are allowed).
#[track_caller]
#[inline(always)]
fn assert_valid_channel_count(count: usize) {
    if count > 2 {
        panic!("NotSupportedError - ConvolverNode channel count cannot be greater than two");
    }
}
/// Panics with a `NotSupportedError` message when the requested channel
/// count mode is `Max`, which `ConvolverNode` does not support.
#[track_caller]
#[inline(always)]
fn assert_valid_channel_count_mode(mode: ChannelCountMode) {
    let forbidden = ChannelCountMode::Max;
    assert_ne!(
        mode, forbidden,
        "NotSupportedError - ConvolverNode channel count mode cannot be set to max"
    );
}
/// Processing node which applies a linear convolution effect given an
/// impulse response buffer.
///
/// The buffer itself stays on the control thread; pre-built convolution
/// engines are shipped to the render thread via a message (see
/// `set_buffer`).
#[derive(Debug)]
pub struct ConvolverNode {
    /// Handle tying this node to its audio context
    registration: AudioContextRegistration,
    /// Channel configuration (count / mode / interpretation)
    channel_config: ChannelConfig,
    /// Whether equal-power normalization is applied when a buffer is set
    normalize: bool,
    /// The impulse response buffer, as provided via `set_buffer`
    buffer: Option<AudioBuffer>,
}
impl AudioNode for ConvolverNode {
    fn registration(&self) -> &AudioContextRegistration {
        &self.registration
    }

    fn channel_config(&self) -> &ChannelConfig {
        &self.channel_config
    }

    /// The ConvolverNode has exactly one input.
    fn number_of_inputs(&self) -> usize {
        1
    }

    /// The ConvolverNode has exactly one output.
    fn number_of_outputs(&self) -> usize {
        1
    }

    /// # Panics
    ///
    /// Panics when the channel count is greater than two.
    fn set_channel_count(&self, count: usize) {
        assert_valid_channel_count(count);
        self.channel_config.set_count(count, self.registration());
    }

    /// # Panics
    ///
    /// Panics when the channel count mode is `Max`.
    fn set_channel_count_mode(&self, mode: ChannelCountMode) {
        assert_valid_channel_count_mode(mode);
        self.channel_config
            .set_count_mode(mode, self.registration());
    }
}
impl ConvolverNode {
    /// Creates a `ConvolverNode` registered on the given context.
    ///
    /// # Panics
    ///
    /// Panics when:
    /// - the requested channel count is greater than two,
    /// - the requested channel count mode is `Max`,
    /// - a buffer is supplied that fails the checks in [`Self::set_buffer`].
    pub fn new<C: BaseAudioContext>(context: &C, options: ConvolverOptions) -> Self {
        let ConvolverOptions {
            buffer,
            disable_normalization,
            audio_node_options,
        } = options;
        assert_valid_channel_count(audio_node_options.channel_count);
        assert_valid_channel_count_mode(audio_node_options.channel_count_mode);
        let mut node = context.base().register(move |registration| {
            // The renderer starts without convolvers; they are delivered
            // later via a `ConvolverInfosMessage` when a buffer is set.
            let renderer = ConvolverRenderer {
                convolvers: None,
                impulse_length: 0,
                impulse_number_of_channels: 0,
                tail_count: 0,
            };
            let node = Self {
                registration,
                channel_config: audio_node_options.into(),
                normalize: !disable_normalization,
                buffer: None,
            };
            (node, Box::new(renderer))
        });
        // Route the initial buffer (if any) through the regular setter so
        // validation and normalization apply.
        if let Some(buffer) = buffer {
            node.set_buffer(buffer);
        }
        node
    }

    /// Returns the current impulse response buffer, if any.
    pub fn buffer(&self) -> Option<&AudioBuffer> {
        self.buffer.as_ref()
    }

    /// Sets the impulse response buffer, builds the convolution engines and
    /// ships them to the render thread.
    ///
    /// # Panics
    ///
    /// Panics when the buffer's sample rate differs from the context sample
    /// rate, or when the buffer does not have 1, 2 or 4 channels.
    pub fn set_buffer(&mut self, buffer: AudioBuffer) {
        let sample_rate = buffer.sample_rate();
        assert_eq!(
            sample_rate,
            self.context().sample_rate(),
            "NotSupportedError - sample rate of the convolution buffer must match the audio context"
        );
        let number_of_channels = buffer.number_of_channels();
        assert!(
            [1, 2, 4].contains(&number_of_channels),
            "NotSupportedError - the convolution buffer must consist of 1, 2 or 4 channels"
        );
        // Gain applied to every impulse response sample: equal-power
        // normalization, or unity when normalization is disabled.
        let scale = if self.normalize {
            normalize_buffer(&buffer)
        } else {
            1.
        };
        let mut convolvers = Vec::<FFTConvolver<f32>>::new();
        let partition_size = RENDER_QUANTUM_SIZE * 8;
        // Build at least two convolvers so a mono impulse response can serve
        // both channels of a stereo input; the channel index is clamped to
        // the last available channel.
        for index in 0..number_of_channels.max(2) {
            let channel = index.min(number_of_channels - 1);
            let mut scaled_channel = vec![0.; buffer.length()];
            scaled_channel
                .iter_mut()
                .zip(buffer.get_channel_data(channel))
                .for_each(|(o, i)| *o = *i * scale);
            let mut convolver = FFTConvolver::<f32>::default();
            convolver
                .init(partition_size, &scaled_channel)
                .expect("Unable to initialize convolution engine");
            convolvers.push(convolver);
        }
        let msg = ConvolverInfosMessage {
            convolvers: Some(convolvers),
            impulse_length: buffer.length(),
            impulse_number_of_channels: number_of_channels,
        };
        self.registration.post_message(msg);
        self.buffer = Some(buffer);
    }

    /// Returns whether equal-power normalization is applied when a buffer
    /// is set.
    pub fn normalize(&self) -> bool {
        self.normalize
    }

    /// Enables or disables normalization. Takes effect on the next
    /// `set_buffer` call; an already-installed buffer is not re-scaled.
    pub fn set_normalize(&mut self, value: bool) {
        self.normalize = value;
    }
}
/// Message sent from the control thread to the renderer when a new impulse
/// response buffer is installed.
struct ConvolverInfosMessage {
    /// Pre-built convolution engines; always `Some` when sent, wrapped in an
    /// `Option` so the renderer can swap them in place without allocating
    convolvers: Option<Vec<FFTConvolver<f32>>>,
    /// Length in frames of the impulse response
    impulse_length: usize,
    /// Number of channels of the impulse response buffer
    impulse_number_of_channels: usize,
}
/// Render-thread counterpart of `ConvolverNode`, performing the actual
/// partitioned FFT convolution per render quantum.
struct ConvolverRenderer {
    /// Convolution engines; `None` until a buffer message is received
    convolvers: Option<Vec<FFTConvolver<f32>>>,
    /// Length in frames of the impulse response
    impulse_length: usize,
    /// Number of channels of the impulse response buffer
    impulse_number_of_channels: usize,
    /// Frames rendered since the input went silent (tail-time tracking)
    tail_count: usize,
}
impl AudioProcessor for ConvolverRenderer {
    /// Convolves the input quantum with the installed impulse response.
    ///
    /// Returns `true` while the node must keep rendering: whenever the input
    /// is non-silent, or while the ringing tail is shorter than the impulse
    /// response length.
    fn process(
        &mut self,
        inputs: &[AudioRenderQuantum],
        outputs: &mut [AudioRenderQuantum],
        _params: AudioParamValues<'_>,
        _scope: &AudioWorkletGlobalScope,
    ) -> bool {
        // Single input / single output node.
        let input = &inputs[0];
        let output = &mut outputs[0];
        output.force_mono();

        // Without an impulse response the node acts as a pass-through.
        let convolvers = match &mut self.convolvers {
            None => {
                *output = input.clone();
                return !input.is_silent();
            }
            Some(convolvers) => convolvers,
        };

        // Dispatch on (input channels, impulse response channels).
        match (input.number_of_channels(), self.impulse_number_of_channels) {
            (1, 1) => {
                // mono input, mono response -> mono output
                output.set_number_of_channels(1);
                let i = &input.channel_data(0)[..];
                let o = &mut output.channel_data_mut(0)[..];
                let _ = convolvers[0].process(i, o);
            }
            (1, 2) => {
                // mono input, stereo response -> the same input feeds both
                // per-channel convolvers
                output.set_number_of_channels(2);
                let i = &input.channel_data(0)[..];
                let o_left = &mut output.channel_data_mut(0)[..];
                let _ = convolvers[0].process(i, o_left);
                let o_right = &mut output.channel_data_mut(1)[..];
                let _ = convolvers[1].process(i, o_right);
            }
            // Stereo input with a mono or stereo response: one convolver per
            // channel. For a mono response, `set_buffer` built two convolvers
            // holding the same impulse, so both cases share a single arm
            // (the two original arms were identical).
            (2, 1) | (2, 2) => {
                output.set_number_of_channels(2);
                let i_left = &input.channel_data(0)[..];
                let o_left = &mut output.channel_data_mut(0)[..];
                let _ = convolvers[0].process(i_left, o_left);
                let i_right = &input.channel_data(1)[..];
                let o_right = &mut output.channel_data_mut(1)[..];
                let _ = convolvers[1].process(i_right, o_right);
            }
            (2, 4) => {
                // "True stereo" response: four convolutions (LL, LR, RL, RR)
                // computed into four channels, then mixed down to stereo.
                output.set_number_of_channels(4);
                let i_left = &input.channel_data(0)[..];
                let o_0 = &mut output.channel_data_mut(0)[..];
                let _ = convolvers[0].process(i_left, o_0);
                let o_1 = &mut output.channel_data_mut(1)[..];
                let _ = convolvers[1].process(i_left, o_1);
                let i_right = &input.channel_data(1)[..];
                let o_2 = &mut output.channel_data_mut(2)[..];
                let _ = convolvers[2].process(i_right, o_2);
                let o_3 = &mut output.channel_data_mut(3)[..];
                let _ = convolvers[3].process(i_right, o_3);
                // Mix channels 2/3 into 0/1, then drop back to stereo.
                let o_2 = output.channel_data(2).clone();
                let o_3 = output.channel_data(3).clone();
                output
                    .channel_data_mut(0)
                    .iter_mut()
                    .zip(o_2.iter())
                    .for_each(|(l, sl)| *l += *sl);
                output
                    .channel_data_mut(1)
                    .iter_mut()
                    .zip(o_3.iter())
                    .for_each(|(r, sr)| *r += *sr);
                output.set_number_of_channels(2);
            }
            (1, 4) => {
                // Mono input with a true-stereo response: the same input
                // feeds all four convolvers before the stereo mix-down.
                output.set_number_of_channels(4);
                let i = &input.channel_data(0)[..];
                let o_0 = &mut output.channel_data_mut(0)[..];
                let _ = convolvers[0].process(i, o_0);
                let o_1 = &mut output.channel_data_mut(1)[..];
                let _ = convolvers[1].process(i, o_1);
                let o_2 = &mut output.channel_data_mut(2)[..];
                let _ = convolvers[2].process(i, o_2);
                let o_3 = &mut output.channel_data_mut(3)[..];
                let _ = convolvers[3].process(i, o_3);
                // Mix channels 2/3 into 0/1, then drop back to stereo.
                let o_2 = output.channel_data(2).clone();
                let o_3 = output.channel_data(3).clone();
                output
                    .channel_data_mut(0)
                    .iter_mut()
                    .zip(o_2.iter())
                    .for_each(|(l, sl)| *l += *sl);
                output
                    .channel_data_mut(1)
                    .iter_mut()
                    .zip(o_3.iter())
                    .for_each(|(r, sr)| *r += *sr);
                output.set_number_of_channels(2);
            }
            // The node's channel count is asserted to be at most two and the
            // buffer is validated to have 1, 2 or 4 channels, so no other
            // combination can occur.
            _ => unreachable!(),
        }

        // Tail handling: keep rendering for the length of the impulse
        // response after the input goes silent, then allow the graph to
        // release the node.
        if input.is_silent() {
            self.tail_count += RENDER_QUANTUM_SIZE;
            return self.tail_count < self.impulse_length;
        }
        self.tail_count = 0;
        true
    }

    fn onmessage(&mut self, msg: &mut dyn Any) {
        if let Some(msg) = msg.downcast_mut::<ConvolverInfosMessage>() {
            let ConvolverInfosMessage {
                convolvers,
                impulse_length,
                impulse_number_of_channels,
            } = msg;
            // Swap moves the pre-built convolvers in (and any previous set
            // out) without allocating or deallocating on the render thread.
            std::mem::swap(&mut self.convolvers, convolvers);
            self.impulse_length = *impulse_length;
            self.impulse_number_of_channels = *impulse_number_of_channels;
            return;
        }
        log::warn!("ConvolverRenderer: Dropping incoming message {msg:?}");
    }
}
#[cfg(test)]
mod tests {
    use float_eq::assert_float_eq;

    use crate::context::{BaseAudioContext, OfflineAudioContext};
    use crate::node::{AudioBufferSourceNode, AudioBufferSourceOptions, AudioScheduledSourceNode};

    use super::*;

    // Constructing a node with a buffer whose sample rate differs from the
    // context must panic with NotSupportedError.
    #[test]
    #[should_panic]
    fn test_buffer_sample_rate_matches() {
        let context = OfflineAudioContext::new(1, 128, 44100.);
        let ir = vec![1.];
        let ir = AudioBuffer::from(vec![ir; 1], 48000.);
        let options = ConvolverOptions {
            buffer: Some(ir),
            ..ConvolverOptions::default()
        };
        let _ = ConvolverNode::new(&context, options);
    }

    // Only 1, 2 or 4 channel impulse responses are accepted; 3 must panic.
    #[test]
    #[should_panic]
    fn test_buffer_must_have_1_2_4_channels() {
        let context = OfflineAudioContext::new(1, 128, 48000.);
        let ir = vec![1.];
        let ir = AudioBuffer::from(vec![ir; 3], 48000.);
        let options = ConvolverOptions {
            buffer: Some(ir),
            ..ConvolverOptions::default()
        };
        let _ = ConvolverNode::new(&context, options);
    }

    // A buffer passed through the constructor options behaves like one set
    // via `set_buffer`: a unit impulse response scaled by the normalization
    // calibration gain reproduces the input times that gain.
    #[test]
    fn test_constructor_options_buffer() {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, 10, sample_rate);
        let ir = vec![1.];
        let calibration = 0.00125;
        let channel_data = vec![0., 1., 0., -1., 0.];
        let expected = [0., calibration, 0., -calibration, 0., 0., 0., 0., 0., 0.];
        let ir = AudioBuffer::from(vec![ir; 1], sample_rate);
        let options = ConvolverOptions {
            buffer: Some(ir),
            ..ConvolverOptions::default()
        };
        let conv = ConvolverNode::new(&context, options);
        conv.connect(&context.destination());
        let buffer = AudioBuffer::from(vec![channel_data; 1], sample_rate);
        let mut src = context.create_buffer_source();
        src.connect(&conv);
        src.set_buffer(buffer);
        src.start();
        let output = context.start_rendering_sync();
        assert_float_eq!(output.get_channel_data(0), &expected[..], abs_all <= 1E-6);
    }

    // Helper: render `signal` through a ConvolverNode configured with the
    // given (optional) mono impulse response, producing `length` frames.
    fn test_convolve(signal: &[f32], impulse_resp: Option<Vec<f32>>, length: usize) -> AudioBuffer {
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(1, length, sample_rate);
        let input = AudioBuffer::from(vec![signal.to_vec()], sample_rate);
        let mut src = AudioBufferSourceNode::new(&context, AudioBufferSourceOptions::default());
        src.set_buffer(input);
        src.start();
        let mut conv = ConvolverNode::new(&context, ConvolverOptions::default());
        if let Some(ir) = impulse_resp {
            conv.set_buffer(AudioBuffer::from(vec![ir.to_vec()], sample_rate));
        }
        src.connect(&conv);
        conv.connect(&context.destination());
        context.start_rendering_sync()
    }

    // No impulse response set: the node passes the input through unchanged.
    #[test]
    fn test_passthrough() {
        let output = test_convolve(&[0., 1., 0., -1., 0.], None, 10);
        let expected = [0., 1., 0., -1., 0., 0., 0., 0., 0., 0.];
        assert_float_eq!(output.get_channel_data(0), &expected[..], abs_all <= 1E-6);
    }

    // An empty impulse response yields silence.
    #[test]
    fn test_empty() {
        let ir = vec![];
        let output = test_convolve(&[0., 1., 0., -1., 0.], Some(ir), 10);
        let expected = [0.; 10];
        assert_float_eq!(output.get_channel_data(0), &expected[..], abs_all <= 1E-6);
    }

    // An all-zero impulse response yields silence.
    #[test]
    fn test_zeroed() {
        let ir = vec![0., 0., 0., 0., 0., 0.];
        let output = test_convolve(&[0., 1., 0., -1., 0.], Some(ir), 10);
        let expected = [0.; 10];
        assert_float_eq!(output.get_channel_data(0), &expected[..], abs_all <= 1E-6);
    }

    // A unit impulse reproduces the input scaled by the calibration gain
    // (normalization of a single-sample IR reduces to that constant).
    #[test]
    fn test_identity() {
        let ir = vec![1.];
        let calibration = 0.00125;
        let output = test_convolve(&[0., 1., 0., -1., 0.], Some(ir), 10);
        let expected = [0., calibration, 0., -calibration, 0., 0., 0., 0., 0., 0.];
        assert_float_eq!(output.get_channel_data(0), &expected[..], abs_all <= 1E-6);
    }

    // A two-sample response smears each input sample over two frames.
    #[test]
    fn test_two_id() {
        let ir = vec![1., 1.];
        let calibration = 0.00125;
        let output = test_convolve(&[0., 1., 0., -1., 0.], Some(ir), 10);
        let expected = [
            0.,
            calibration,
            calibration,
            -calibration,
            -calibration,
            0.,
            0.,
            0.,
            0.,
            0.,
        ];
        assert_float_eq!(output.get_channel_data(0), &expected[..], abs_all <= 1E-6);
    }

    // After the input ends, the node must keep ringing for the length of the
    // impulse response and be silent afterwards.
    #[test]
    fn test_should_have_tail_time() {
        const IR_LEN: usize = 256;
        let ir = vec![1.; IR_LEN];
        let input = &[1.];
        let output = test_convolve(input, Some(ir), 512);
        let output = output.channel_data(0).as_slice();
        assert!(!output[..IR_LEN].iter().any(|v| *v <= 1E-6));
        assert_float_eq!(&output[IR_LEN..], &[0.; 512 - IR_LEN][..], abs_all <= 1E-6);
    }

    // (1 input channel, 1 IR channel): mono output, IR delay applied.
    #[test]
    fn test_channel_config_1_chan_in_1_chan_ir() {
        let number_of_channels = 1;
        let length = 128;
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(number_of_channels, length, sample_rate);
        let input = AudioBuffer::from(vec![vec![1.]], sample_rate);
        let ir = AudioBuffer::from(vec![vec![0., 1.]], sample_rate);
        let mut src = AudioBufferSourceNode::new(
            &context,
            AudioBufferSourceOptions {
                buffer: Some(input),
                ..AudioBufferSourceOptions::default()
            },
        );
        let conv = ConvolverNode::new(
            &context,
            ConvolverOptions {
                buffer: Some(ir),
                disable_normalization: true,
                ..ConvolverOptions::default()
            },
        );
        src.connect(&conv);
        conv.connect(&context.destination());
        src.start();
        let result = context.start_rendering_sync();
        let mut expected = [0.; 128];
        expected[1] = 1.;
        assert_float_eq!(
            result.get_channel_data(0)[..],
            expected[..],
            abs_all <= 1e-7
        );
    }

    // (1 input channel, 2 IR channels): stereo output, each IR channel
    // convolved with the same mono input.
    #[test]
    fn test_channel_config_1_chan_in_2_chan_ir() {
        let number_of_channels = 2;
        let length = 128;
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(number_of_channels, length, sample_rate);
        let input = AudioBuffer::from(vec![vec![1.]], sample_rate);
        let ir = AudioBuffer::from(vec![vec![0., 1., 0.], vec![0., 0., 1.]], sample_rate);
        let mut src = AudioBufferSourceNode::new(
            &context,
            AudioBufferSourceOptions {
                buffer: Some(input),
                ..AudioBufferSourceOptions::default()
            },
        );
        let conv = ConvolverNode::new(
            &context,
            ConvolverOptions {
                buffer: Some(ir),
                disable_normalization: true,
                ..ConvolverOptions::default()
            },
        );
        src.connect(&conv);
        conv.connect(&context.destination());
        src.start();
        let result = context.start_rendering_sync();
        let mut expected_left = [0.; 128];
        expected_left[1] = 1.;
        let mut expected_right = [0.; 128];
        expected_right[2] = 1.;
        assert_eq!(result.number_of_channels(), 2);
        assert_float_eq!(
            result.get_channel_data(0)[..],
            expected_left[..],
            abs_all <= 1e-7
        );
        assert_float_eq!(
            result.get_channel_data(1)[..],
            expected_right[..],
            abs_all <= 1e-7
        );
    }

    // (2 input channels, 1 IR channel): the mono IR is applied to each input
    // channel independently.
    #[test]
    fn test_channel_config_2_chan_in_1_chan_ir() {
        let number_of_channels = 2;
        let length = 128;
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(number_of_channels, length, sample_rate);
        let input = AudioBuffer::from(vec![vec![1., 0.], vec![0., 1.]], sample_rate);
        let ir = AudioBuffer::from(vec![vec![0., 1.]], sample_rate);
        let mut src = AudioBufferSourceNode::new(
            &context,
            AudioBufferSourceOptions {
                buffer: Some(input),
                ..AudioBufferSourceOptions::default()
            },
        );
        let conv = ConvolverNode::new(
            &context,
            ConvolverOptions {
                buffer: Some(ir),
                disable_normalization: true,
                ..ConvolverOptions::default()
            },
        );
        src.connect(&conv);
        conv.connect(&context.destination());
        src.start();
        let result = context.start_rendering_sync();
        let mut expected_left = [0.; 128];
        expected_left[1] = 1.;
        let mut expected_right = [0.; 128];
        expected_right[2] = 1.;
        assert_eq!(result.number_of_channels(), 2);
        assert_float_eq!(
            result.get_channel_data(0)[..],
            expected_left[..],
            abs_all <= 1e-7
        );
        assert_float_eq!(
            result.get_channel_data(1)[..],
            expected_right[..],
            abs_all <= 1e-7
        );
    }

    // (2 input channels, 2 IR channels): each IR channel is applied to the
    // corresponding input channel.
    #[test]
    fn test_channel_config_2_chan_in_2_chan_ir() {
        let number_of_channels = 2;
        let length = 128;
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(number_of_channels, length, sample_rate);
        let input = AudioBuffer::from(vec![vec![1., 0.], vec![0., 1.]], sample_rate);
        let ir = AudioBuffer::from(vec![vec![0., 1., 0.], vec![0., 0., 1.]], sample_rate);
        let mut src = AudioBufferSourceNode::new(
            &context,
            AudioBufferSourceOptions {
                buffer: Some(input),
                ..AudioBufferSourceOptions::default()
            },
        );
        let conv = ConvolverNode::new(
            &context,
            ConvolverOptions {
                buffer: Some(ir),
                disable_normalization: true,
                ..ConvolverOptions::default()
            },
        );
        src.connect(&conv);
        conv.connect(&context.destination());
        src.start();
        let result = context.start_rendering_sync();
        let mut expected_left = [0.; 128];
        expected_left[1] = 1.;
        let mut expected_right = [0.; 128];
        expected_right[3] = 1.;
        assert_eq!(result.number_of_channels(), 2);
        assert_float_eq!(
            result.get_channel_data(0)[..],
            expected_left[..],
            abs_all <= 1e-7
        );
        assert_float_eq!(
            result.get_channel_data(1)[..],
            expected_right[..],
            abs_all <= 1e-7
        );
    }

    // (2 input channels, 4 IR channels): "true stereo" — four convolutions
    // mixed down to stereo.
    #[test]
    fn test_channel_config_2_chan_in_4_chan_ir() {
        let number_of_channels = 2;
        let length = 128;
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(number_of_channels, length, sample_rate);
        let input = AudioBuffer::from(vec![vec![1., 0.], vec![0., 1.]], sample_rate);
        let ir = AudioBuffer::from(
            vec![
                vec![0., 1., 0., 0., 0.],
                vec![0., 0., 1., 0., 0.],
                vec![0., 0., 0., 1., 0.],
                vec![0., 0., 0., 0., 1.],
            ],
            sample_rate,
        );
        let mut src = AudioBufferSourceNode::new(
            &context,
            AudioBufferSourceOptions {
                buffer: Some(input),
                ..AudioBufferSourceOptions::default()
            },
        );
        let conv = ConvolverNode::new(
            &context,
            ConvolverOptions {
                buffer: Some(ir),
                disable_normalization: true,
                ..ConvolverOptions::default()
            },
        );
        src.connect(&conv);
        conv.connect(&context.destination());
        src.start();
        let result = context.start_rendering_sync();
        let mut expected_left = [0.; 128];
        expected_left[1] = 1.;
        expected_left[4] = 1.;
        let mut expected_right = [0.; 128];
        expected_right[2] = 1.;
        expected_right[5] = 1.;
        assert_eq!(result.number_of_channels(), 2);
        assert_float_eq!(
            result.get_channel_data(0)[..],
            expected_left[..],
            abs_all <= 1e-7
        );
        assert_float_eq!(
            result.get_channel_data(1)[..],
            expected_right[..],
            abs_all <= 1e-7
        );
    }

    // (1 input channel, 4 IR channels): the mono input feeds all four
    // convolvers before the stereo mix-down.
    #[test]
    fn test_channel_config_1_chan_in_4_chan_ir() {
        let number_of_channels = 2;
        let length = 128;
        let sample_rate = 44100.;
        let mut context = OfflineAudioContext::new(number_of_channels, length, sample_rate);
        let input = AudioBuffer::from(vec![vec![1., 0.]], sample_rate);
        let ir = AudioBuffer::from(
            vec![
                vec![0., 1., 0., 0., 0.],
                vec![0., 0., 1., 0., 0.],
                vec![0., 0., 0., 1., 0.],
                vec![0., 0., 0., 0., 1.],
            ],
            sample_rate,
        );
        let mut src = AudioBufferSourceNode::new(
            &context,
            AudioBufferSourceOptions {
                buffer: Some(input),
                ..AudioBufferSourceOptions::default()
            },
        );
        let conv = ConvolverNode::new(
            &context,
            ConvolverOptions {
                buffer: Some(ir),
                disable_normalization: true,
                ..ConvolverOptions::default()
            },
        );
        src.connect(&conv);
        conv.connect(&context.destination());
        src.start();
        let result = context.start_rendering_sync();
        let mut expected_left = [0.; 128];
        expected_left[1] = 1.;
        expected_left[3] = 1.;
        let mut expected_right = [0.; 128];
        expected_right[2] = 1.;
        expected_right[4] = 1.;
        assert_eq!(result.number_of_channels(), 2);
        assert_float_eq!(
            result.get_channel_data(0)[..],
            expected_left[..],
            abs_all <= 1e-7
        );
        assert_float_eq!(
            result.get_channel_data(1)[..],
            expected_right[..],
            abs_all <= 1e-7
        );
    }
}