// firewheel_nodes/convolution.rs — FFT convolution audio node.

1use core::f32;
2
3use fft_convolver::FFTConvolver;
4use firewheel_core::{
5    channel_config::{ChannelConfig, ChannelCount},
6    collector::OwnedGc,
7    diff::{Diff, Patch},
8    dsp::{
9        declick::{DeclickFadeCurve, Declicker},
10        fade::FadeCurve,
11        filter::smoothing_filter::DEFAULT_SMOOTH_SECONDS,
12        mix::{Mix, MixDSP},
13        volume::Volume,
14    },
15    event::NodeEventType,
16    node::{
17        AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, ProcessStatus,
18    },
19    param::smoother::{SmoothedParam, SmootherConfig},
20    sample_resource::SampleResourceF32,
21};
22
/// A mono (1-channel) [`ConvolutionNode`].
pub type ConvolutionMonoNode = ConvolutionNode<1>;
/// A stereo (2-channel) [`ConvolutionNode`].
pub type ConvolutionStereoNode = ConvolutionNode<2>;
25
/// Imparts characteristics of an [`ImpulseResponse`] to the input signal.
///
/// Convolution is often used to achieve reverb effects, but is more
/// computationally expensive than algorithmic reverb.
#[derive(Patch, Diff, Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNode<const CHANNELS: usize = 2> {
    /// Pause the convolution processing.
    ///
    /// This prevents a tail from ringing out when you want all sound to
    /// momentarily pause.
    ///
    /// By default this is set to `false`.
    pub pause: bool,

    /// The value representing the mix between the two audio signals
    ///
    /// This is a normalized value in the range `[0.0, 1.0]`, where `0.0` is
    /// fully the first signal, `1.0` is fully the second signal, and `0.5` is
    /// an equal mix of both.
    ///
    /// By default this is set to [`Mix::CENTER`].
    pub mix: Mix,

    /// The algorithm used to map the normalized mix value in the range `[0.0,
    /// 1.0]` to the corresponding gain values for the two signals.
    ///
    /// By default this is set to [`FadeCurve::EqualPower3dB`].
    pub fade_curve: FadeCurve,

    /// The gain applied to the resulting convolved signal.
    ///
    /// Defaults to -20dB to balance the volume increase likely to occur when
    /// convolving audio. Values closer to 1.0 may be very loud.
    pub wet_gain: Volume,

    /// Adjusts the time in seconds over which parameters are smoothed for `mix`
    /// and `wet_gain`.
    ///
    /// Defaults to `0.015` (15ms).
    pub smooth_seconds: f32,
}
68
/// Node configuration for a mono (1-channel) [`ConvolutionNode`].
pub type ConvolutionMonoNodeConfig = ConvolutionNodeConfig<1>;
/// Node configuration for a stereo (2-channel) [`ConvolutionNode`].
pub type ConvolutionStereoNodeConfig = ConvolutionNodeConfig<2>;
71
/// Node configuration for [`ConvolutionNode`].
///
/// NOTE(review): neither field is read anywhere in this file
/// (`construct_processor` ignores its configuration) — confirm these are
/// consumed elsewhere in the crate.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNodeConfig<const CHANNELS: usize = 2> {
    /// The maximum number of supported IR channels (must be
    /// `ChannelCount::MONO` or `ChannelCount::STEREO`). This determines the
    /// number of buffers allocated. Loading an impulse response with more
    /// channels than supported will result in the remaining channels being
    /// removed.
    pub max_impulse_channel_count: ChannelCount,

    /// The FFT partition size to use for convolution.
    ///
    /// Smaller blocks may reduce latency at the cost of increased CPU usage.
    ///
    /// Defaults to [`DEFAULT_PARTITION_SIZE`] (`1024`).
    pub partition_size: usize,
}
87
/// The default partition size to use with a [`ConvolutionNode`].
///
/// Smaller blocks may reduce latency at the cost of increased CPU usage.
///
/// Used by [`ImpulseResponse::new`] and by
/// `ConvolutionNodeConfig::default()`.
pub const DEFAULT_PARTITION_SIZE: usize = 1024;
92
/// A processed impulse response sample.
///
/// Holds one pre-initialized [`FFTConvolver`] per channel of the source
/// sample (see [`ImpulseResponse::new_with_partition_size`]).
///
/// `ImpulseResponse`s are used in [`ConvolutionNode`]s.
pub struct ImpulseResponse(Vec<FFTConvolver<f32>>);
97
98impl ImpulseResponse {
99    /// Create a new `ImpulseResponse` with a custom partition size.
100    ///
101    /// Smaller blocks may reduce latency at the cost of increased CPU usage.
102    pub fn new_with_partition_size(sample: impl SampleResourceF32, partition_size: usize) -> Self {
103        let num_channels = sample.num_channels().get();
104        Self(
105            (0..num_channels)
106                .map(|channel_index| {
107                    let mut conv = FFTConvolver::default();
108                    // The sample channel must exist, as our iterator is based
109                    // on its length. The FFT may error, depending on several
110                    // factors. Currently, this will result in a panic.
111                    conv.init(partition_size, sample.channel(channel_index).unwrap())
112                        .unwrap();
113                    conv
114                })
115                .collect(),
116        )
117    }
118
119    /// Create a new `ImpulseResponse` with a default partition size of `1024`.
120    pub fn new(sample: impl SampleResourceF32) -> Self {
121        Self::new_with_partition_size(sample, DEFAULT_PARTITION_SIZE)
122    }
123}
124
125impl<const CHANNELS: usize> Default for ConvolutionNodeConfig<CHANNELS> {
126    fn default() -> Self {
127        Self {
128            // A Convolution node with 0 `CHANNELS` is invalid and will panic.
129            max_impulse_channel_count: ChannelCount::new(CHANNELS as u32).unwrap(),
130            partition_size: DEFAULT_PARTITION_SIZE,
131        }
132    }
133}
134
135impl<const CHANNELS: usize> Default for ConvolutionNode<CHANNELS> {
136    fn default() -> Self {
137        Self {
138            mix: Mix::CENTER,
139            fade_curve: FadeCurve::default(),
140            wet_gain: Volume::Decibels(-20.0),
141            pause: false,
142            smooth_seconds: DEFAULT_SMOOTH_SECONDS,
143        }
144    }
145}
146
147impl<const CHANNELS: usize> AudioNode for ConvolutionNode<CHANNELS> {
148    type Configuration = ConvolutionNodeConfig<CHANNELS>;
149
150    fn info(&self, _configuration: &Self::Configuration) -> AudioNodeInfo {
151        if CHANNELS > 2 {
152            panic!(
153                "ConvolutionNode::CHANNELS cannot be greater than 2, got {}",
154                CHANNELS
155            );
156        }
157        AudioNodeInfo::new()
158            .debug_name("convolution")
159            .channel_config(ChannelConfig::new(CHANNELS, CHANNELS))
160    }
161
162    fn construct_processor(
163        &self,
164        _configuration: &Self::Configuration,
165        cx: ConstructProcessorContext,
166    ) -> impl AudioNodeProcessor {
167        let sample_rate = cx.stream_info.sample_rate;
168        let smooth_config = SmootherConfig {
169            smooth_seconds: self.smooth_seconds,
170            ..Default::default()
171        };
172        ConvolutionProcessor::<CHANNELS> {
173            params: self.clone(),
174            mix: MixDSP::new(self.mix, self.fade_curve, smooth_config, sample_rate),
175            wet_gain_smoothed: SmoothedParam::new(self.wet_gain.amp(), smooth_config, sample_rate),
176            declick: Declicker::default(),
177            impulse_response: OwnedGc::new(None),
178            next_impulse_response: OwnedGc::new(None),
179        }
180    }
181}
182
/// Events accepted by a [`ConvolutionNode`].
///
/// NOTE(review): the processor downcasts custom events directly into an
/// `Option<ImpulseResponse>`; confirm how this enum maps onto that payload
/// at the event-dispatch layer.
pub enum ConvolutionNodeEvent {
    /// Load a new [`ImpulseResponse`] into the node, or clear the current
    /// one with `None`.
    SetImpulseResponse(Option<ImpulseResponse>),
}
186
/// Real-time processor for [`ConvolutionNode`].
struct ConvolutionProcessor<const CHANNELS: usize> {
    // Latest parameter values, kept in sync via patches in `process`.
    params: ConvolutionNode<CHANNELS>,
    // Dry/wet mixer driven by `params.mix` and `params.fade_curve`.
    mix: MixDSP,
    // Smoothed amplitude for `params.wet_gain`.
    wet_gain_smoothed: SmoothedParam,
    // Fades the output to/from silence on pause and IR changes to avoid clicks.
    declick: Declicker,
    // The currently active impulse response, if any.
    impulse_response: OwnedGc<Option<ImpulseResponse>>,
    // We cannot be certain that the transition to a new impulse response will
    // happen within one block, so we must store the old impulse response until
    // the declicker settles.
    next_impulse_response: OwnedGc<Option<ImpulseResponse>>,
}
198
impl<const CHANNELS: usize> AudioNodeProcessor for ConvolutionProcessor<CHANNELS> {
    /// Processes one block: drains parameter/IR events, swaps in a pending
    /// impulse response once the declicker has settled at silence, convolves
    /// the input with the active IR, applies wet gain, mixes dry into wet,
    /// and finally runs the declick fade over the outputs.
    fn process(
        &mut self,
        info: &firewheel_core::node::ProcInfo,
        buffers: firewheel_core::node::ProcBuffers,
        events: &mut firewheel_core::event::ProcEvents,
        extra: &mut firewheel_core::node::ProcExtra,
    ) -> ProcessStatus {
        for mut event in events.drain() {
            match event {
                NodeEventType::Param { data, path } => {
                    if let Ok(patch) = ConvolutionNode::<CHANNELS>::patch(&data, &path) {
                        // You can match on the patch directly
                        match patch {
                            ConvolutionNodePatch::Mix(mix) => {
                                self.mix.set_mix(mix, self.params.fade_curve);
                            }
                            ConvolutionNodePatch::FadeCurve(curve) => {
                                self.mix.set_mix(self.params.mix, curve);
                            }
                            ConvolutionNodePatch::WetGain(gain) => {
                                self.wet_gain_smoothed.set_value(gain.amp());
                            }
                            ConvolutionNodePatch::Pause(pause) => {
                                // Fade rather than hard-stop to avoid clicks.
                                self.declick.fade_to_enabled(!pause, &extra.declick_values);
                            }
                            ConvolutionNodePatch::SmoothSeconds(smooth_seconds) => {
                                // Rebuild the mixer so its smoother picks up
                                // the new smoothing time.
                                self.mix = MixDSP::new(
                                    self.params.mix,
                                    self.params.fade_curve,
                                    SmootherConfig {
                                        smooth_seconds,
                                        ..Default::default()
                                    },
                                    info.sample_rate,
                                );
                                self.wet_gain_smoothed
                                    .set_smooth_seconds(smooth_seconds, info.sample_rate);
                            }
                        }
                        // Keep the mirrored parameter struct in sync.
                        self.params.apply(patch);
                    }
                }
                NodeEventType::Custom(_) => {
                    // An incoming IR is parked in `next_impulse_response`
                    // until the fade-out below completes.
                    if event.downcast_into_owned(&mut self.next_impulse_response) {
                        // Disable the audio stream while changing IRs
                        self.declick.fade_to_0(&extra.declick_values);
                    }
                }
                _ => (),
            }
        }

        // Check to see if there is a new IR waiting. If there is, and the audio
        // has stopped, swap the IR, and continue
        if self.next_impulse_response.is_some() && self.declick == Declicker::SettledAt0 {
            // The next impulse result must exist due to the check in this block
            let next_impulse_response = self.next_impulse_response.take().unwrap();
            self.impulse_response.replace(next_impulse_response);
            // Don't unpause if we're paused manually
            if !self.params.pause {
                self.declick.fade_to_1(&extra.declick_values);
            }
            // Begin mixing back in with the new impulse response next block
            return ProcessStatus::ClearAllOutputs;
        }

        // Only process if an impulse response is supplied
        if self.impulse_response.is_some() {
            const WET_GAIN_BUFFER: usize = 0;
            let wet_gain_buffer = &mut extra.scratch_buffers.channels_mut::<1>()[WET_GAIN_BUFFER];

            // Amount to scale based on wet signal gain
            self.wet_gain_smoothed.process_into_buffer(wet_gain_buffer);

            // If paused, return early after processing wet gain buffers to
            // avoid clicking
            if self.params.pause && self.declick == Declicker::SettledAt0 {
                return ProcessStatus::ClearAllOutputs;
            }

            for (input_index, input) in buffers.inputs.iter().enumerate() {
                // We unfortunately can't add more buffers to the convolution
                // struct, as we don't own it. This means we can't do stereo
                // with a mono impulse response. In this case, we'll just pass
                // the input through if we can't get a channel.
                //
                // NOTE(review): when the IR has fewer channels than the input,
                // this loop skips the channel without copying the input to the
                // output — the output for that channel is only touched by the
                // dry/wet mix below. Confirm that matches the "pass through"
                // intent stated above.

                // We already checked that the impulse response must exist, so
                // we can safely unwrap.
                if let Some(conv) = self
                    .impulse_response
                    .get_mut()
                    .as_mut()
                    .unwrap()
                    .0
                    .get_mut(input_index)
                {
                    conv.process(input, buffers.outputs[input_index]).unwrap();

                    // Apply wet signal gain
                    for (output_sample, gain) in buffers.outputs[input_index]
                        .iter_mut()
                        .zip(wet_gain_buffer.iter())
                    {
                        *output_sample *= gain;
                    }
                }
            }
        }

        if self.impulse_response.is_some() {
            // Blend the dry input into the convolved (wet) output in place.
            match CHANNELS {
                1 => {
                    self.mix.mix_dry_into_wet_mono(
                        buffers.inputs[0],
                        buffers.outputs[0],
                        info.frames,
                    );
                }
                2 => {
                    // Split so we can borrow both output channels mutably.
                    let (left, right) = buffers.outputs.split_at_mut(1);
                    self.mix.mix_dry_into_wet_stereo(
                        buffers.inputs[0],
                        buffers.inputs[1],
                        left[0],
                        right[0],
                        info.frames,
                    );
                }
                _ => panic!("Only Mono and Stereo are supported"),
            }
        } else {
            // Pass through audio if no impulse provided
            for (input, output) in buffers.inputs.iter().zip(buffers.outputs.iter_mut()) {
                output.copy_from_slice(input);
            }
        }

        // Apply any in-progress pause/IR-change fade to the final outputs.
        self.declick.process(
            buffers.outputs,
            0..info.frames,
            &extra.declick_values,
            1.0,
            DeclickFadeCurve::EqualPower3dB,
        );

        buffers.check_for_silence_on_outputs(f32::EPSILON)
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    // Mono and stereo channel counts are accepted by `info`.
    #[test]
    fn mono_stereo_ok() {
        let mono = ConvolutionNode::<1>::default();
        mono.info(&ConvolutionNodeConfig::default());
        let stereo = ConvolutionNode::<2>::default();
        stereo.info(&ConvolutionNodeConfig::default());
    }

    // Three or more channels must panic in `info`.
    #[test]
    #[should_panic]
    fn fail_above_stereo() {
        let node = ConvolutionNode::<3>::default();
        node.info(&ConvolutionNodeConfig::default());
    }
}