// firewheel_nodes/convolution.rs

1use core::f32;
2
3use fft_convolver::FFTConvolver;
4use firewheel_core::{
5    channel_config::{ChannelConfig, ChannelCount},
6    collector::OwnedGc,
7    diff::{Diff, Patch},
8    dsp::{
9        declick::{DeclickFadeCurve, Declicker},
10        fade::FadeCurve,
11        filter::smoothing_filter::DEFAULT_SMOOTH_SECONDS,
12        mix::{Mix, MixDSP},
13        volume::Volume,
14    },
15    event::NodeEventType,
16    node::{
17        AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, ProcessStatus,
18    },
19    param::smoother::{SmoothedParam, SmootherConfig},
20    sample_resource::SampleResourceF32,
21};
22
/// Imparts characteristics of an [`ImpulseResponse`] to the input signal.
///
/// Convolution is often used to achieve reverb effects, but is more
/// computationally expensive than algorithmic reverb.
///
/// The `CHANNELS` const parameter is the number of input (and output)
/// channels. Only `1` (mono) and `2` (stereo) are supported; requesting
/// more causes [`AudioNode::info`] to panic.
#[derive(Patch, Diff, Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNode<const CHANNELS: usize> {
    /// Pause the convolution processing.
    ///
    /// This prevents a tail from ringing out when you want all sound to
    /// momentarily pause.
    pub pause: bool,

    /// The value representing the mix between the two audio signals
    ///
    /// This is a normalized value in the range `[0.0, 1.0]`, where `0.0` is
    /// fully the first signal (dry), `1.0` is fully the second signal (wet),
    /// and `0.5` is an equal mix of both.
    ///
    /// By default this is set to [`Mix::CENTER`].
    pub mix: Mix,

    /// The algorithm used to map the normalized mix value in the range `[0.0,
    /// 1.0]` to the corresponding gain values for the two signals.
    ///
    /// By default this is set to [`FadeCurve::EqualPower3dB`].
    pub fade_curve: FadeCurve,

    /// The gain applied to the resulting convolved signal.
    ///
    /// Defaults to -20dB to balance the volume increase likely to occur when
    /// convolving audio. Values closer to 1.0 may be very loud.
    pub wet_gain: Volume,

    /// Adjusts the time in seconds over which parameters are smoothed for `mix`
    /// and `wet_gain`.
    ///
    /// Defaults to `0.015` (15ms).
    pub smooth_seconds: f32,
}
65
/// Node configuration for [`ConvolutionNode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNodeConfig<const CHANNELS: usize> {
    /// The maximum number of supported IR channels (must be
    /// `ChannelCount::MONO` or `ChannelCount::STEREO`). This determines the
    /// number of buffers allocated. Loading an impulse response with more
    /// channels than supported will result in the remaining channels being
    /// removed.
    pub max_impulse_channel_count: ChannelCount,

    /// The FFT partition (block) size used by the convolver.
    ///
    /// Smaller blocks may reduce latency at the cost of increased CPU usage.
    /// Defaults to [`DEFAULT_PARTITION_SIZE`].
    pub partition_size: usize,
}
81
/// The default partition size to use with a [`ConvolutionNode`].
///
/// Smaller blocks may reduce latency at the cost of increased CPU usage.
/// Used by [`ConvolutionNodeConfig::default`] and [`ImpulseResponse::new`].
pub const DEFAULT_PARTITION_SIZE: usize = 1024;
86
/// A processed impulse response sample.
///
/// Internally holds one pre-initialized [`FFTConvolver`] per channel of the
/// source sample.
///
/// `ImpulseResponse`s are used in [`ConvolutionNode`]s.
pub struct ImpulseResponse(Vec<FFTConvolver<f32>>);
91
92impl ImpulseResponse {
93    /// Create a new `ImpulseResponse` with a custom partition size.
94    ///
95    /// Smaller blocks may reduce latency at the cost of increased CPU usage.
96    pub fn new_with_partition_size(sample: impl SampleResourceF32, partition_size: usize) -> Self {
97        let num_channels = sample.num_channels().get();
98        Self(
99            (0..num_channels)
100                .map(|channel_index| {
101                    let mut conv = FFTConvolver::default();
102                    // The sample channel must exist, as our iterator is based
103                    // on its length. The FFT may error, depending on several
104                    // factors. Currently, this will result in a panic.
105                    conv.init(partition_size, sample.channel(channel_index).unwrap())
106                        .unwrap();
107                    conv
108                })
109                .collect(),
110        )
111    }
112
113    /// Create a new `ImpulseResponse` with a default partition size of `1024`.
114    pub fn new(sample: impl SampleResourceF32) -> Self {
115        Self::new_with_partition_size(sample, DEFAULT_PARTITION_SIZE)
116    }
117}
118
119impl<const CHANNELS: usize> Default for ConvolutionNodeConfig<CHANNELS> {
120    fn default() -> Self {
121        Self {
122            // A Convolution node with 0 `CHANNELS` is invalid and will panic.
123            max_impulse_channel_count: ChannelCount::new(CHANNELS as u32).unwrap(),
124            partition_size: DEFAULT_PARTITION_SIZE,
125        }
126    }
127}
128
129impl<const CHANNELS: usize> Default for ConvolutionNode<CHANNELS> {
130    fn default() -> Self {
131        Self {
132            mix: Mix::CENTER,
133            fade_curve: FadeCurve::default(),
134            wet_gain: Volume::Decibels(-20.0),
135            pause: false,
136            smooth_seconds: DEFAULT_SMOOTH_SECONDS,
137        }
138    }
139}
140
141impl<const CHANNELS: usize> AudioNode for ConvolutionNode<CHANNELS> {
142    type Configuration = ConvolutionNodeConfig<CHANNELS>;
143
144    fn info(&self, _configuration: &Self::Configuration) -> AudioNodeInfo {
145        if CHANNELS > 2 {
146            panic!(
147                "ConvolutionNode::CHANNELS cannot be greater than 2, got {}",
148                CHANNELS
149            );
150        }
151        AudioNodeInfo::new()
152            .debug_name("convolution")
153            .channel_config(ChannelConfig::new(CHANNELS, CHANNELS))
154    }
155
156    fn construct_processor(
157        &self,
158        _configuration: &Self::Configuration,
159        cx: ConstructProcessorContext,
160    ) -> impl AudioNodeProcessor {
161        let sample_rate = cx.stream_info.sample_rate;
162        let smooth_config = SmootherConfig {
163            smooth_seconds: self.smooth_seconds,
164            ..Default::default()
165        };
166        ConvolutionProcessor::<CHANNELS> {
167            params: self.clone(),
168            mix: MixDSP::new(self.mix, self.fade_curve, smooth_config, sample_rate),
169            wet_gain_smoothed: SmoothedParam::new(self.wet_gain.amp(), smooth_config, sample_rate),
170            declick: Declicker::default(),
171            impulse_response: OwnedGc::new(None),
172            next_impulse_response: OwnedGc::new(None),
173        }
174    }
175}
176
/// Custom events accepted by a [`ConvolutionNode`].
pub enum ConvolutionNodeEvent {
    /// Supply a new [`ImpulseResponse`] for the processor to fade over to.
    ///
    /// NOTE(review): the processor only swaps IRs when the pending slot holds
    /// `Some` — confirm whether sending `None` is intended to clear the
    /// currently active impulse response.
    SetImpulseResponse(Option<ImpulseResponse>),
}
180
/// Realtime processor counterpart of [`ConvolutionNode`].
struct ConvolutionProcessor<const CHANNELS: usize> {
    /// Mirror of the node's parameters, kept in sync via patch events.
    params: ConvolutionNode<CHANNELS>,
    /// Dry/wet mixer with built-in parameter smoothing.
    mix: MixDSP,
    /// Smoothed wet-signal gain (stored as linear amplitude via `amp()`).
    wet_gain_smoothed: SmoothedParam,
    /// Fades the output in/out for pause and IR-swap transitions.
    declick: Declicker,
    /// The currently active impulse response, if any.
    impulse_response: OwnedGc<Option<ImpulseResponse>>,
    // We cannot be certain that the transition to a new impulse response will
    // happen within one block, so we must store the old impulse response until
    // the declicker settles.
    next_impulse_response: OwnedGc<Option<ImpulseResponse>>,
}
192
impl<const CHANNELS: usize> AudioNodeProcessor for ConvolutionProcessor<CHANNELS> {
    /// Process one audio block: apply parameter patches and IR-swap events,
    /// convolve the input with the active impulse response, apply the wet
    /// gain, mix with the dry signal, and declick transitions.
    fn process(
        &mut self,
        info: &firewheel_core::node::ProcInfo,
        buffers: firewheel_core::node::ProcBuffers,
        events: &mut firewheel_core::event::ProcEvents,
        extra: &mut firewheel_core::node::ProcExtra,
    ) -> ProcessStatus {
        // Drain all pending events before doing any DSP for this block.
        for mut event in events.drain() {
            match event {
                NodeEventType::Param { data, path } => {
                    if let Ok(patch) = ConvolutionNode::<CHANNELS>::patch(&data, &path) {
                        // You can match on the patch directly
                        match patch {
                            ConvolutionNodePatch::Mix(mix) => {
                                self.mix.set_mix(mix, self.params.fade_curve);
                            }
                            ConvolutionNodePatch::FadeCurve(curve) => {
                                self.mix.set_mix(self.params.mix, curve);
                            }
                            ConvolutionNodePatch::WetGain(gain) => {
                                // Smoother works in linear amplitude.
                                self.wet_gain_smoothed.set_value(gain.amp());
                            }
                            ConvolutionNodePatch::Pause(pause) => {
                                // Fade rather than hard-stop to avoid clicks.
                                self.declick.fade_to_enabled(!pause, &extra.declick_values);
                            }
                            ConvolutionNodePatch::SmoothSeconds(smooth_seconds) => {
                                // Changing the smoothing time requires
                                // rebuilding the mix DSP with a new config.
                                self.mix = MixDSP::new(
                                    self.params.mix,
                                    self.params.fade_curve,
                                    SmootherConfig {
                                        smooth_seconds,
                                        ..Default::default()
                                    },
                                    info.sample_rate,
                                );
                                self.wet_gain_smoothed
                                    .set_smooth_seconds(smooth_seconds, info.sample_rate);
                            }
                        }
                        // Keep the mirrored parameter struct in sync so later
                        // patches see up-to-date values.
                        self.params.apply(patch);
                    }
                }
                NodeEventType::Custom(_) => {
                    if event.downcast_into_owned(&mut self.next_impulse_response) {
                        // Disable the audio stream while changing IRs
                        self.declick.fade_to_0(&extra.declick_values);
                    }
                }
                _ => (),
            }
        }

        // Check to see if there is a new IR waiting. If there is, and the audio
        // has stopped (fade-out settled), swap the IR, and continue
        if self.next_impulse_response.is_some() && self.declick == Declicker::SettledAt0 {
            // The next impulse result must exist due to the check in this block
            let next_impulse_response = self.next_impulse_response.take().unwrap();
            self.impulse_response.replace(next_impulse_response);
            // Don't unpause if we're paused manually
            if !self.params.pause {
                self.declick.fade_to_1(&extra.declick_values);
            }
            // Begin mixing back in with the new impulse response next block
            return ProcessStatus::ClearAllOutputs;
        }

        // Only process if an impulse response is supplied
        if self.impulse_response.is_some() {
            // Scratch channel 0 holds the per-frame smoothed wet gain.
            const WET_GAIN_BUFFER: usize = 0;
            let wet_gain_buffer = &mut extra.scratch_buffers.channels_mut::<1>()[WET_GAIN_BUFFER];

            // Amount to scale based on wet signal gain
            self.wet_gain_smoothed.process_into_buffer(wet_gain_buffer);

            // If paused, return early after processing wet gain buffers to
            // avoid clicking
            if self.params.pause && self.declick == Declicker::SettledAt0 {
                return ProcessStatus::ClearAllOutputs;
            }

            for (input_index, input) in buffers.inputs.iter().enumerate() {
                // We unfortunately can't add more buffers to the convolution
                // struct, as we don't own it. This means we can't do stereo
                // with a mono impulse response. In this case, we'll just pass
                // the input through if we can't get a channel.

                // We already checked that the impulse response must exist, so
                // we can safely unwrap.
                if let Some(conv) = self
                    .impulse_response
                    .get_mut()
                    .as_mut()
                    .unwrap()
                    .0
                    .get_mut(input_index)
                {
                    // Convolve the input directly into the output channel.
                    conv.process(input, buffers.outputs[input_index]).unwrap();

                    // Apply wet signal gain
                    for (output_sample, gain) in buffers.outputs[input_index]
                        .iter_mut()
                        .zip(wet_gain_buffer.iter())
                    {
                        *output_sample *= gain;
                    }
                }
            }
        }

        if self.impulse_response.is_some() {
            // Blend the dry input into the wet output in place.
            match CHANNELS {
                1 => {
                    self.mix.mix_dry_into_wet_mono(
                        buffers.inputs[0],
                        buffers.outputs[0],
                        info.frames,
                    );
                }
                2 => {
                    // Split so we can borrow both output channels mutably.
                    let (left, right) = buffers.outputs.split_at_mut(1);
                    self.mix.mix_dry_into_wet_stereo(
                        buffers.inputs[0],
                        buffers.inputs[1],
                        left[0],
                        right[0],
                        info.frames,
                    );
                }
                // `info()` rejects CHANNELS > 2 at node construction.
                _ => panic!("Only Mono and Stereo are supported"),
            }
        } else {
            // Pass through audio if no impulse provided
            for (input, output) in buffers.inputs.iter().zip(buffers.outputs.iter_mut()) {
                output.copy_from_slice(input);
            }
        }

        // Apply any in-progress fade (pause/unpause, IR swap) to the outputs.
        self.declick.process(
            buffers.outputs,
            0..info.frames,
            &extra.declick_values,
            1.0,
            DeclickFadeCurve::EqualPower3dB,
        );

        buffers.check_for_silence_on_outputs(f32::EPSILON)
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a default node with `C` channels and queries its info,
    /// exercising the channel-count validation in `info`.
    fn default_node_info<const C: usize>() {
        let node = ConvolutionNode::<C>::default();
        node.info(&ConvolutionNodeConfig::<C>::default());
    }

    // Behave as expected up to stereo
    #[test]
    fn mono_stereo_ok() {
        default_node_info::<1>();
        default_node_info::<2>();
    }

    // Error when 3+ channels are requested
    #[test]
    #[should_panic]
    fn fail_above_stereo() {
        default_node_info::<3>();
    }
}