firewheel_nodes/convolution.rs

use core::f32;

use fft_convolver::FFTConvolver;
use firewheel_core::{
    channel_config::{ChannelConfig, ChannelCount},
    collector::OwnedGc,
    diff::{Diff, Patch},
    dsp::{
        declick::{DeclickFadeCurve, Declicker},
        fade::FadeCurve,
        filter::smoothing_filter::DEFAULT_SMOOTH_SECONDS,
        mix::{Mix, MixDSP},
        volume::Volume,
    },
    event::NodeEventType,
    node::{
        AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, ProcessStatus,
    },
    param::smoother::{SmoothedParam, SmootherConfig},
    sample_resource::SampleResourceF32,
};

/// Imparts characteristics of an [`ImpulseResponse`] to the input signal.
///
/// Convolution is often used to achieve reverb effects, but is more
/// computationally expensive than algorithmic reverb.
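///
/// # Example
///
/// A minimal sketch of configuring the node; the values here are
/// illustrative, not recommendations:
///
/// ```ignore
/// let convolution = ConvolutionNode::<2> {
///     mix: Mix::CENTER,
///     wet_gain: Volume::Decibels(-12.0),
///     ..Default::default()
/// };
/// ```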
#[derive(Clone, Copy, Patch, Diff, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
pub struct ConvolutionNode<const CHANNELS: usize> {
    /// Pause the convolution processing.
    ///
    /// This prevents a tail from ringing out when you want all sound to
    /// momentarily pause.
    pub pause: bool,

    /// The value representing the mix between the dry and wet signals.
    ///
    /// This is a normalized value in the range `[0.0, 1.0]`, where `0.0` is
    /// fully the dry (input) signal, `1.0` is fully the wet (convolved)
    /// signal, and `0.5` is an equal mix of both.
    ///
    /// By default this is set to [`Mix::CENTER`].
    pub mix: Mix,

    /// The algorithm used to map the normalized mix value in the range `[0.0,
    /// 1.0]` to the corresponding gain values for the two signals.
    ///
    /// By default this is set to [`FadeCurve::EqualPower3dB`].
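    ///
    /// For example, with an equal-power curve a mix of `0.5` gives both
    /// signals a gain of roughly `0.707` (about -3 dB).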
    pub fade_curve: FadeCurve,

    /// The gain applied to the resulting convolved signal.
    ///
    /// Defaults to -20 dB to balance the volume increase that typically
    /// occurs when convolving audio. Values closer to unity gain (0 dB) may
    /// be very loud.
    pub wet_gain: Volume,

    /// Adjusts the time in seconds over which parameters are smoothed for `mix`
    /// and `wet_gain`.
    ///
    /// Defaults to `0.015` (15 ms).
    pub smooth_seconds: f32,
}

/// Node configuration for [`ConvolutionNode`].
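///
/// # Example
///
/// A minimal sketch of overriding the partition size; the value is
/// illustrative:
///
/// ```ignore
/// let config = ConvolutionNodeConfig::<2> {
///     partition_size: 512,
///     ..Default::default()
/// };
/// ```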
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub struct ConvolutionNodeConfig<const CHANNELS: usize> {
    /// The maximum number of supported IR channels (must be
    /// `ChannelCount::MONO` or `ChannelCount::STEREO`). This determines the
    /// number of buffers allocated. Loading an impulse response with more
    /// channels than supported will cause the extra channels to be dropped.
    pub max_impulse_channel_count: ChannelCount,

    /// The partition size to use for the convolution.
    ///
    /// See [`DEFAULT_PARTITION_SIZE`].
    pub partition_size: usize,
}

/// The default partition size to use with a [`ConvolutionNode`].
///
/// Smaller blocks may reduce latency at the cost of increased CPU usage.
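///
/// For reference, `1024` samples correspond to roughly 21 ms of audio at a
/// 48 kHz sample rate.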
pub const DEFAULT_PARTITION_SIZE: usize = 1024;

/// A processed impulse response sample.
///
/// `ImpulseResponse`s are used in [`ConvolutionNode`]s.
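///
/// # Example
///
/// A minimal sketch; `load_sample()` is a hypothetical stand-in for whatever
/// produces an `impl SampleResourceF32` in your application:
///
/// ```ignore
/// let sample = load_sample("impulse.wav");
/// let ir = ImpulseResponse::new(sample);
/// ```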
pub struct ImpulseResponse(Vec<FFTConvolver<f32>>);

impl ImpulseResponse {
    /// Create a new `ImpulseResponse` with a custom partition size.
    ///
    /// Smaller blocks may reduce latency at the cost of increased CPU usage.
    pub fn new_with_partition_size(sample: impl SampleResourceF32, partition_size: usize) -> Self {
        let num_channels = sample.num_channels().get();
        Self(
            (0..num_channels)
                .map(|channel_index| {
                    let mut conv = FFTConvolver::default();
                    // The sample channel must exist, since the iterator range
                    // is based on the channel count. Initializing the
                    // convolver may fail for several reasons; currently this
                    // results in a panic.
                    conv.init(partition_size, sample.channel(channel_index).unwrap())
                        .unwrap();
                    conv
                })
                .collect(),
        )
    }

    /// Create a new `ImpulseResponse` with a default partition size of `1024`.
    pub fn new(sample: impl SampleResourceF32) -> Self {
        Self::new_with_partition_size(sample, DEFAULT_PARTITION_SIZE)
    }
}

impl<const CHANNELS: usize> Default for ConvolutionNodeConfig<CHANNELS> {
    fn default() -> Self {
        Self {
            // A Convolution node with 0 `CHANNELS` is invalid and will panic.
            max_impulse_channel_count: ChannelCount::new(CHANNELS as u32).unwrap(),
            partition_size: DEFAULT_PARTITION_SIZE,
        }
    }
}

impl<const CHANNELS: usize> Default for ConvolutionNode<CHANNELS> {
    fn default() -> Self {
        Self {
            mix: Mix::CENTER,
            fade_curve: FadeCurve::default(),
            wet_gain: Volume::Decibels(-20.0),
            pause: false,
            smooth_seconds: DEFAULT_SMOOTH_SECONDS,
        }
    }
}

impl<const CHANNELS: usize> AudioNode for ConvolutionNode<CHANNELS> {
    type Configuration = ConvolutionNodeConfig<CHANNELS>;

    fn info(&self, _configuration: &Self::Configuration) -> AudioNodeInfo {
        if CHANNELS > 2 {
            panic!(
                "ConvolutionNode::CHANNELS cannot be greater than 2, got {}",
                CHANNELS
            );
        }
        AudioNodeInfo::new()
            .debug_name("convolution")
            .channel_config(ChannelConfig::new(CHANNELS, CHANNELS))
    }

    fn construct_processor(
        &self,
        _configuration: &Self::Configuration,
        cx: ConstructProcessorContext,
    ) -> impl AudioNodeProcessor {
        let sample_rate = cx.stream_info.sample_rate;
        let smooth_config = SmootherConfig {
            smooth_seconds: self.smooth_seconds,
            ..Default::default()
        };
        ConvolutionProcessor::<CHANNELS> {
            params: self.clone(),
            mix: MixDSP::new(self.mix, self.fade_curve, smooth_config, sample_rate),
            wet_gain_smoothed: SmoothedParam::new(self.wet_gain.amp(), smooth_config, sample_rate),
            declick: Declicker::default(),
            impulse_response: OwnedGc::new(None),
            next_impulse_response: OwnedGc::new(None),
        }
    }
}

/// A custom event type for the [`ConvolutionNode`].
pub enum ConvolutionNodeEvent {
    /// Replace the node's current [`ImpulseResponse`], or clear it with `None`.
    SetImpulseResponse(Option<ImpulseResponse>),
}

struct ConvolutionProcessor<const CHANNELS: usize> {
    params: ConvolutionNode<CHANNELS>,
    mix: MixDSP,
    wet_gain_smoothed: SmoothedParam,
    declick: Declicker,
    impulse_response: OwnedGc<Option<ImpulseResponse>>,
    // The fade-out before switching impulse responses may not finish within a
    // single block, so the incoming impulse response is stored here until the
    // declicker settles.
    next_impulse_response: OwnedGc<Option<ImpulseResponse>>,
}

impl<const CHANNELS: usize> AudioNodeProcessor for ConvolutionProcessor<CHANNELS> {
    fn process(
        &mut self,
        info: &firewheel_core::node::ProcInfo,
        buffers: firewheel_core::node::ProcBuffers,
        events: &mut firewheel_core::event::ProcEvents,
        extra: &mut firewheel_core::node::ProcExtra,
    ) -> ProcessStatus {
        for mut event in events.drain() {
            match event {
                NodeEventType::Param { data, path } => {
                    if let Ok(patch) = ConvolutionNode::<CHANNELS>::patch(&data, &path) {
                        // You can match on the patch directly
                        match patch {
                            ConvolutionNodePatch::Mix(mix) => {
                                self.mix.set_mix(mix, self.params.fade_curve);
                            }
                            ConvolutionNodePatch::FadeCurve(curve) => {
                                self.mix.set_mix(self.params.mix, curve);
                            }
                            ConvolutionNodePatch::WetGain(gain) => {
                                self.wet_gain_smoothed.set_value(gain.amp());
                            }
                            ConvolutionNodePatch::Pause(pause) => {
                                self.declick.fade_to_enabled(!pause, &extra.declick_values);
                            }
                            ConvolutionNodePatch::SmoothSeconds(smooth_seconds) => {
                                self.mix = MixDSP::new(
                                    self.params.mix,
                                    self.params.fade_curve,
                                    SmootherConfig {
                                        smooth_seconds,
                                        ..Default::default()
                                    },
                                    info.sample_rate,
                                );
                                self.wet_gain_smoothed
                                    .set_smooth_seconds(smooth_seconds, info.sample_rate);
                            }
                        }
                        self.params.apply(patch);
                    }
                }
                NodeEventType::Custom(_) => {
                    if event.downcast_into_owned(&mut self.next_impulse_response) {
                        // Fade this node's output to silence while the IR is
                        // being swapped.
                        self.declick.fade_to_0(&extra.declick_values);
                    }
                }
                _ => (),
            }
        }

        // Check whether a new IR is waiting. If there is one and the output
        // has fully faded out, swap in the new IR and continue.
        if self.next_impulse_response.is_some() && self.declick == Declicker::SettledAt0 {
            // The next impulse response must exist due to the check above.
            let next_impulse_response = self.next_impulse_response.take().unwrap();
            self.impulse_response.replace(next_impulse_response);
            // Don't unpause if we're paused manually.
            if !self.params.pause {
                self.declick.fade_to_1(&extra.declick_values);
            }
            // Begin mixing back in with the new impulse response next block.
            return ProcessStatus::ClearAllOutputs;
        }

        // Only process if an impulse response is supplied.
        if self.impulse_response.is_some() {
            const WET_GAIN_BUFFER: usize = 0;
            let wet_gain_buffer = &mut extra.scratch_buffers.channels_mut::<1>()[WET_GAIN_BUFFER];

            // Fill the scratch buffer with the smoothed wet signal gain.
            self.wet_gain_smoothed.process_into_buffer(wet_gain_buffer);

            // If paused, return early after processing the wet gain buffer to
            // avoid clicking.
            if self.params.pause && self.declick == Declicker::SettledAt0 {
                return ProcessStatus::ClearAllOutputs;
            }

            for (input_index, input) in buffers.inputs.iter().enumerate() {
                // We unfortunately can't add more buffers to the convolution
                // struct, as we don't own it. This means we can't do stereo
                // with a mono impulse response. In that case, pass the input
                // through for any channel without a matching convolver.

                // We already checked that the impulse response must exist, so
                // we can safely unwrap.
                if let Some(conv) = self
                    .impulse_response
                    .get_mut()
                    .as_mut()
                    .unwrap()
                    .0
                    .get_mut(input_index)
                {
                    conv.process(input, buffers.outputs[input_index]).unwrap();

                    // Apply the wet signal gain.
                    for (output_sample, gain) in buffers.outputs[input_index]
                        .iter_mut()
                        .zip(wet_gain_buffer.iter())
                    {
                        *output_sample *= gain;
                    }
                } else {
                    // No convolver for this channel: pass the dry input through.
                    buffers.outputs[input_index].copy_from_slice(input);
                }
            }
        }

        if self.impulse_response.is_some() {
            match CHANNELS {
                1 => {
                    self.mix.mix_dry_into_wet_mono(
                        buffers.inputs[0],
                        buffers.outputs[0],
                        info.frames,
                    );
                }
                2 => {
                    let (left, right) = buffers.outputs.split_at_mut(1);
                    self.mix.mix_dry_into_wet_stereo(
                        buffers.inputs[0],
                        buffers.inputs[1],
                        left[0],
                        right[0],
                        info.frames,
                    );
                }
                _ => panic!("Only Mono and Stereo are supported"),
            }
        } else {
            // Pass through audio if no impulse provided
            for (input, output) in buffers.inputs.iter().zip(buffers.outputs.iter_mut()) {
                output.copy_from_slice(input);
            }
        }

        self.declick.process(
            buffers.outputs,
            0..info.frames,
            &extra.declick_values,
            1.0,
            DeclickFadeCurve::EqualPower3dB,
        );

        buffers.check_for_silence_on_outputs(f32::EPSILON)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    // Mono and stereo channel counts behave as expected.
    #[test]
    fn mono_stereo_ok() {
        ConvolutionNode::<1>::default().info(&ConvolutionNodeConfig::default());
        ConvolutionNode::<2>::default().info(&ConvolutionNodeConfig::default());
    }

    // Panic when 3 or more channels are requested.
    #[test]
    #[should_panic]
    fn fail_above_stereo() {
        ConvolutionNode::<3>::default().info(&ConvolutionNodeConfig::default());
    }
}