1use core::f32;
2
3use fft_convolver::FFTConvolver;
4use firewheel_core::{
5 channel_config::{ChannelConfig, ChannelCount},
6 collector::OwnedGc,
7 diff::{Diff, Patch},
8 dsp::{
9 declick::{DeclickFadeCurve, Declicker},
10 fade::FadeCurve,
11 filter::smoothing_filter::DEFAULT_SMOOTH_SECONDS,
12 mix::{Mix, MixDSP},
13 volume::Volume,
14 },
15 event::NodeEventType,
16 node::{
17 AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, ProcessStatus,
18 },
19 param::smoother::{SmoothedParam, SmootherConfig},
20 sample_resource::SampleResourceF32,
21};
22
/// Parameters of a convolution node with `CHANNELS` input and `CHANNELS`
/// output channels.
///
/// The processor built from this node convolves each input channel with the
/// currently loaded [`ImpulseResponse`] (the "wet" signal) and mixes the
/// unprocessed input (the "dry" signal) into it. While no impulse response is
/// loaded the input is copied to the output unchanged.
///
/// Only mono (`CHANNELS == 1`) and stereo (`CHANNELS == 2`) are supported;
/// `info` panics for larger values.
#[derive(Clone, Copy, Patch, Diff, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
pub struct ConvolutionNode<const CHANNELS: usize> {
    /// When `true`, the node output is faded out by the declicker; when set
    /// back to `false` it is faded in again.
    ///
    /// Defaults to `false`.
    pub pause: bool,

    /// The balance between the dry (input) and wet (convolved) signals.
    ///
    /// Defaults to [`Mix::CENTER`].
    pub mix: Mix,

    /// The fade curve used together with [`ConvolutionNode::mix`] when the
    /// dry signal is mixed into the wet signal.
    ///
    /// Defaults to [`FadeCurve::default`].
    pub fade_curve: FadeCurve,

    /// The gain applied to the convolved (wet) signal before the dry signal
    /// is mixed in.
    ///
    /// Defaults to `-20` decibels.
    pub wet_gain: Volume,

    /// The smoothing time, in seconds, used when `mix` and `wet_gain` change.
    ///
    /// Defaults to [`DEFAULT_SMOOTH_SECONDS`].
    pub smooth_seconds: f32,
}
63
/// Configuration for a [`ConvolutionNode`].
///
/// The default uses `CHANNELS` as the maximum impulse response channel count
/// and [`DEFAULT_PARTITION_SIZE`] as the partition size.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub struct ConvolutionNodeConfig<const CHANNELS: usize> {
    /// Presumably the largest number of channels an impulse response may
    /// have for this node.
    ///
    /// NOTE(review): no code visible in this file reads this field — confirm
    /// where (or whether) it is enforced.
    pub max_impulse_channel_count: ChannelCount,

    /// Presumably the partition size intended for the FFT convolvers.
    ///
    /// NOTE(review): no code visible in this file reads this field; the
    /// partition size actually used is the one passed to
    /// [`ImpulseResponse::new_with_partition_size`] — confirm.
    pub partition_size: usize,
}
78
/// The partition size used by [`ImpulseResponse::new`] and by the default
/// [`ConvolutionNodeConfig`].
pub const DEFAULT_PARTITION_SIZE: usize = 1024;
83
/// An impulse response that has been prepared for convolution.
///
/// Holds one initialized [`FFTConvolver`] per channel of the sample it was
/// built from, in channel order.
pub struct ImpulseResponse(Vec<FFTConvolver<f32>>);
88
89impl ImpulseResponse {
90 pub fn new_with_partition_size(sample: impl SampleResourceF32, partition_size: usize) -> Self {
94 let num_channels = sample.num_channels().get();
95 Self(
96 (0..num_channels)
97 .map(|channel_index| {
98 let mut conv = FFTConvolver::default();
99 conv.init(partition_size, sample.channel(channel_index).unwrap())
103 .unwrap();
104 conv
105 })
106 .collect(),
107 )
108 }
109
110 pub fn new(sample: impl SampleResourceF32) -> Self {
112 Self::new_with_partition_size(sample, DEFAULT_PARTITION_SIZE)
113 }
114}
115
116impl<const CHANNELS: usize> Default for ConvolutionNodeConfig<CHANNELS> {
117 fn default() -> Self {
118 Self {
119 max_impulse_channel_count: ChannelCount::new(CHANNELS as u32).unwrap(),
121 partition_size: DEFAULT_PARTITION_SIZE,
122 }
123 }
124}
125
126impl<const CHANNELS: usize> Default for ConvolutionNode<CHANNELS> {
127 fn default() -> Self {
128 Self {
129 mix: Mix::CENTER,
130 fade_curve: FadeCurve::default(),
131 wet_gain: Volume::Decibels(-20.0),
132 pause: false,
133 smooth_seconds: DEFAULT_SMOOTH_SECONDS,
134 }
135 }
136}
137
138impl<const CHANNELS: usize> AudioNode for ConvolutionNode<CHANNELS> {
139 type Configuration = ConvolutionNodeConfig<CHANNELS>;
140
141 fn info(&self, _configuration: &Self::Configuration) -> AudioNodeInfo {
142 if CHANNELS > 2 {
143 panic!(
144 "ConvolutionNode::CHANNELS cannot be greater than 2, got {}",
145 CHANNELS
146 );
147 }
148 AudioNodeInfo::new()
149 .debug_name("convolution")
150 .channel_config(ChannelConfig::new(CHANNELS, CHANNELS))
151 }
152
153 fn construct_processor(
154 &self,
155 _configuration: &Self::Configuration,
156 cx: ConstructProcessorContext,
157 ) -> impl AudioNodeProcessor {
158 let sample_rate = cx.stream_info.sample_rate;
159 let smooth_config = SmootherConfig {
160 smooth_seconds: self.smooth_seconds,
161 ..Default::default()
162 };
163 ConvolutionProcessor::<CHANNELS> {
164 params: self.clone(),
165 mix: MixDSP::new(self.mix, self.fade_curve, smooth_config, sample_rate),
166 wet_gain_smoothed: SmoothedParam::new(self.wet_gain.amp(), smooth_config, sample_rate),
167 declick: Declicker::default(),
168 impulse_response: OwnedGc::new(None),
169 next_impulse_response: OwnedGc::new(None),
170 }
171 }
172}
173
/// Events specific to the convolution node.
///
/// NOTE(review): the processor visible in this file does not match on this
/// enum; it downcasts custom events directly into an
/// `OwnedGc<Option<ImpulseResponse>>`. Confirm how this type is meant to be
/// delivered to the node.
pub enum ConvolutionNodeEvent {
    /// Carries the impulse response to use, or `None` for no impulse response.
    SetImpulseResponse(Option<ImpulseResponse>),
}
177
/// Audio-thread state for a [`ConvolutionNode`].
struct ConvolutionProcessor<const CHANNELS: usize> {
    /// The processor's copy of the node parameters, kept up to date by
    /// applying incoming parameter patches.
    params: ConvolutionNode<CHANNELS>,
    /// Mixes the dry input into the wet (convolved) output.
    mix: MixDSP,
    /// Smoothed amplitude applied to the wet signal.
    wet_gain_smoothed: SmoothedParam,
    /// Fades the output in and out when pausing and when the impulse
    /// response is being replaced.
    declick: Declicker,
    /// The impulse response currently used for convolution, if any.
    impulse_response: OwnedGc<Option<ImpulseResponse>>,
    /// An impulse response received through an event that becomes the
    /// active one once the output has finished fading out.
    next_impulse_response: OwnedGc<Option<ImpulseResponse>>,
}
189
190impl<const CHANNELS: usize> AudioNodeProcessor for ConvolutionProcessor<CHANNELS> {
191 fn process(
192 &mut self,
193 info: &firewheel_core::node::ProcInfo,
194 buffers: firewheel_core::node::ProcBuffers,
195 events: &mut firewheel_core::event::ProcEvents,
196 extra: &mut firewheel_core::node::ProcExtra,
197 ) -> ProcessStatus {
198 for mut event in events.drain() {
199 match event {
200 NodeEventType::Param { data, path } => {
201 if let Ok(patch) = ConvolutionNode::<CHANNELS>::patch(&data, &path) {
202 match patch {
204 ConvolutionNodePatch::Mix(mix) => {
205 self.mix.set_mix(mix, self.params.fade_curve);
206 }
207 ConvolutionNodePatch::FadeCurve(curve) => {
208 self.mix.set_mix(self.params.mix, curve);
209 }
210 ConvolutionNodePatch::WetGain(gain) => {
211 self.wet_gain_smoothed.set_value(gain.amp());
212 }
213 ConvolutionNodePatch::Pause(pause) => {
214 self.declick.fade_to_enabled(!pause, &extra.declick_values);
215 }
216 ConvolutionNodePatch::SmoothSeconds(smooth_seconds) => {
217 self.mix = MixDSP::new(
218 self.params.mix,
219 self.params.fade_curve,
220 SmootherConfig {
221 smooth_seconds,
222 ..Default::default()
223 },
224 info.sample_rate,
225 );
226 self.wet_gain_smoothed
227 .set_smooth_seconds(smooth_seconds, info.sample_rate);
228 }
229 }
230 self.params.apply(patch);
231 }
232 }
233 NodeEventType::Custom(_) => {
234 if event.downcast_into_owned(&mut self.next_impulse_response) {
235 self.declick.fade_to_0(&extra.declick_values);
237 }
238 }
239 _ => (),
240 }
241 }
242
243 if self.next_impulse_response.is_some() && self.declick == Declicker::SettledAt0 {
246 let next_impulse_response = self.next_impulse_response.take().unwrap();
248 self.impulse_response.replace(next_impulse_response);
249 if !self.params.pause {
251 self.declick.fade_to_1(&extra.declick_values);
252 }
253 return ProcessStatus::ClearAllOutputs;
255 }
256
257 if self.impulse_response.is_some() {
259 const WET_GAIN_BUFFER: usize = 0;
260 let wet_gain_buffer = &mut extra.scratch_buffers.channels_mut::<1>()[WET_GAIN_BUFFER];
261
262 self.wet_gain_smoothed.process_into_buffer(wet_gain_buffer);
264
265 if self.params.pause && self.declick == Declicker::SettledAt0 {
268 return ProcessStatus::ClearAllOutputs;
269 }
270
271 for (input_index, input) in buffers.inputs.iter().enumerate() {
272 if let Some(conv) = self
280 .impulse_response
281 .get_mut()
282 .as_mut()
283 .unwrap()
284 .0
285 .get_mut(input_index)
286 {
287 conv.process(input, buffers.outputs[input_index]).unwrap();
288
289 for (output_sample, gain) in buffers.outputs[input_index]
291 .iter_mut()
292 .zip(wet_gain_buffer.iter())
293 {
294 *output_sample *= gain;
295 }
296 }
297 }
298 }
299
300 if self.impulse_response.is_some() {
301 match CHANNELS {
302 1 => {
303 self.mix.mix_dry_into_wet_mono(
304 buffers.inputs[0],
305 buffers.outputs[0],
306 info.frames,
307 );
308 }
309 2 => {
310 let (left, right) = buffers.outputs.split_at_mut(1);
311 self.mix.mix_dry_into_wet_stereo(
312 buffers.inputs[0],
313 buffers.inputs[1],
314 left[0],
315 right[0],
316 info.frames,
317 );
318 }
319 _ => panic!("Only Mono and Stereo are supported"),
320 }
321 } else {
322 for (input, output) in buffers.inputs.iter().zip(buffers.outputs.iter_mut()) {
324 output.copy_from_slice(input);
325 }
326 }
327
328 self.declick.process(
329 buffers.outputs,
330 0..info.frames,
331 &extra.declick_values,
332 1.0,
333 DeclickFadeCurve::EqualPower3dB,
334 );
335
336 buffers.check_for_silence_on_outputs(f32::EPSILON)
337 }
338}
339
#[cfg(test)]
mod tests {
    use super::*;

    /// Calls `info` on a default node with `N` channels and a default
    /// configuration.
    fn query_info<const N: usize>() {
        let node = ConvolutionNode::<N>::default();
        let config = ConvolutionNodeConfig::<N>::default();
        node.info(&config);
    }

    /// Mono and stereo nodes are accepted.
    #[test]
    fn mono_stereo_ok() {
        query_info::<1>();
        query_info::<2>();
    }

    /// Anything wider than stereo is rejected with a panic.
    #[test]
    #[should_panic]
    fn fail_above_stereo() {
        query_info::<3>();
    }
}