1use core::f32;
2
3use fft_convolver::FFTConvolver;
4use firewheel_core::{
5 channel_config::{ChannelConfig, ChannelCount},
6 collector::OwnedGc,
7 diff::{Diff, Patch},
8 dsp::{
9 declick::{DeclickFadeCurve, Declicker},
10 fade::FadeCurve,
11 filter::smoothing_filter::DEFAULT_SMOOTH_SECONDS,
12 mix::{Mix, MixDSP},
13 volume::Volume,
14 },
15 event::NodeEventType,
16 node::{
17 AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, ProcessStatus,
18 },
19 param::smoother::{SmoothedParam, SmootherConfig},
20 sample_resource::SampleResourceF32,
21};
22
/// An audio node that convolves its input with a user-supplied impulse
/// response (sent as a custom event), with a smoothed dry/wet mix, smoothed
/// wet gain, and click-free pausing.
///
/// `CHANNELS` must be `1` (mono) or `2` (stereo); `info` panics otherwise.
#[derive(Patch, Diff, Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNode<const CHANNELS: usize> {
    /// When `true`, the output declick-fades to silence and convolution is
    /// skipped once the fade settles.
    pub pause: bool,

    /// The dry/wet balance of the effect.
    pub mix: Mix,

    /// The curve used when crossfading the dry/wet mix.
    pub fade_curve: FadeCurve,

    /// The gain applied to the wet (convolved) signal.
    pub wet_gain: Volume,

    /// The time in seconds over which parameter changes (mix and wet gain)
    /// are smoothed.
    pub smooth_seconds: f32,
}
65
/// Static configuration for [`ConvolutionNode`].
///
/// NOTE(review): neither field is referenced by `construct_processor` in this
/// file — impulse responses are built by the user via `ImpulseResponse::new*`,
/// which takes its own partition size. Confirm whether these fields are
/// consumed elsewhere or are vestigial.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNodeConfig<const CHANNELS: usize> {
    /// The maximum number of channels an impulse response may have.
    pub max_impulse_channel_count: ChannelCount,

    /// The FFT partition size for the convolver.
    pub partition_size: usize,
}
81
82pub const DEFAULT_PARTITION_SIZE: usize = 1024;
86
87pub struct ImpulseResponse(Vec<FFTConvolver<f32>>);
91
92impl ImpulseResponse {
93 pub fn new_with_partition_size(sample: impl SampleResourceF32, partition_size: usize) -> Self {
97 let num_channels = sample.num_channels().get();
98 Self(
99 (0..num_channels)
100 .map(|channel_index| {
101 let mut conv = FFTConvolver::default();
102 conv.init(partition_size, sample.channel(channel_index).unwrap())
106 .unwrap();
107 conv
108 })
109 .collect(),
110 )
111 }
112
113 pub fn new(sample: impl SampleResourceF32) -> Self {
115 Self::new_with_partition_size(sample, DEFAULT_PARTITION_SIZE)
116 }
117}
118
119impl<const CHANNELS: usize> Default for ConvolutionNodeConfig<CHANNELS> {
120 fn default() -> Self {
121 Self {
122 max_impulse_channel_count: ChannelCount::new(CHANNELS as u32).unwrap(),
124 partition_size: DEFAULT_PARTITION_SIZE,
125 }
126 }
127}
128
129impl<const CHANNELS: usize> Default for ConvolutionNode<CHANNELS> {
130 fn default() -> Self {
131 Self {
132 mix: Mix::CENTER,
133 fade_curve: FadeCurve::default(),
134 wet_gain: Volume::Decibels(-20.0),
135 pause: false,
136 smooth_seconds: DEFAULT_SMOOTH_SECONDS,
137 }
138 }
139}
140
141impl<const CHANNELS: usize> AudioNode for ConvolutionNode<CHANNELS> {
142 type Configuration = ConvolutionNodeConfig<CHANNELS>;
143
144 fn info(&self, _configuration: &Self::Configuration) -> AudioNodeInfo {
145 if CHANNELS > 2 {
146 panic!(
147 "ConvolutionNode::CHANNELS cannot be greater than 2, got {}",
148 CHANNELS
149 );
150 }
151 AudioNodeInfo::new()
152 .debug_name("convolution")
153 .channel_config(ChannelConfig::new(CHANNELS, CHANNELS))
154 }
155
156 fn construct_processor(
157 &self,
158 _configuration: &Self::Configuration,
159 cx: ConstructProcessorContext,
160 ) -> impl AudioNodeProcessor {
161 let sample_rate = cx.stream_info.sample_rate;
162 let smooth_config = SmootherConfig {
163 smooth_seconds: self.smooth_seconds,
164 ..Default::default()
165 };
166 ConvolutionProcessor::<CHANNELS> {
167 params: self.clone(),
168 mix: MixDSP::new(self.mix, self.fade_curve, smooth_config, sample_rate),
169 wet_gain_smoothed: SmoothedParam::new(self.wet_gain.amp(), smooth_config, sample_rate),
170 declick: Declicker::default(),
171 impulse_response: OwnedGc::new(None),
172 next_impulse_response: OwnedGc::new(None),
173 }
174 }
175}
176
/// Custom events accepted by [`ConvolutionNode`].
///
/// NOTE(review): the processor's custom-event path downcasts directly into
/// `OwnedGc<Option<ImpulseResponse>>`; confirm how this enum maps onto that
/// payload at the call site that sends the event.
pub enum ConvolutionNodeEvent {
    /// Replaces the active impulse response (`None` clears it).
    SetImpulseResponse(Option<ImpulseResponse>),
}
180
/// Real-time processor state for [`ConvolutionNode`].
struct ConvolutionProcessor<const CHANNELS: usize> {
    /// Shadow copy of the node's parameters, kept in sync via patches.
    params: ConvolutionNode<CHANNELS>,
    /// Dry/wet crossfade DSP driven by `params.mix` / `params.fade_curve`.
    mix: MixDSP,
    /// Per-sample smoothed wet gain (linear amplitude).
    wet_gain_smoothed: SmoothedParam,
    /// Fades the output for pausing and impulse-response swaps.
    declick: Declicker,
    /// The active impulse response (one convolver per channel), if any.
    impulse_response: OwnedGc<Option<ImpulseResponse>>,
    /// A pending impulse response, swapped in once `declick` settles at 0.
    next_impulse_response: OwnedGc<Option<ImpulseResponse>>,
}
192
impl<const CHANNELS: usize> AudioNodeProcessor for ConvolutionProcessor<CHANNELS> {
    /// Processes one block: applies parameter patches, defers impulse-response
    /// swaps until the output has faded to silence, convolves each channel
    /// with smoothed wet gain, mixes dry into wet, and declicks the outputs.
    fn process(
        &mut self,
        info: &firewheel_core::node::ProcInfo,
        buffers: firewheel_core::node::ProcBuffers,
        events: &mut firewheel_core::event::ProcEvents,
        extra: &mut firewheel_core::node::ProcExtra,
    ) -> ProcessStatus {
        // Drain and apply all pending events before touching audio.
        for mut event in events.drain() {
            match event {
                NodeEventType::Param { data, path } => {
                    if let Ok(patch) = ConvolutionNode::<CHANNELS>::patch(&data, &path) {
                        match patch {
                            ConvolutionNodePatch::Mix(mix) => {
                                self.mix.set_mix(mix, self.params.fade_curve);
                            }
                            ConvolutionNodePatch::FadeCurve(curve) => {
                                self.mix.set_mix(self.params.mix, curve);
                            }
                            ConvolutionNodePatch::WetGain(gain) => {
                                // The smoother operates on linear amplitude.
                                self.wet_gain_smoothed.set_value(gain.amp());
                            }
                            ConvolutionNodePatch::Pause(pause) => {
                                // Fade rather than hard-mute to avoid clicks.
                                self.declick.fade_to_enabled(!pause, &extra.declick_values);
                            }
                            ConvolutionNodePatch::SmoothSeconds(smooth_seconds) => {
                                // Smoothing time changed: rebuild the mix DSP
                                // and retune the wet-gain smoother to match.
                                self.mix = MixDSP::new(
                                    self.params.mix,
                                    self.params.fade_curve,
                                    SmootherConfig {
                                        smooth_seconds,
                                        ..Default::default()
                                    },
                                    info.sample_rate,
                                );
                                self.wet_gain_smoothed
                                    .set_smooth_seconds(smooth_seconds, info.sample_rate);
                            }
                        }
                        // Keep the shadow parameter copy in sync so later
                        // patches (e.g. Mix) read current values.
                        self.params.apply(patch);
                    }
                }
                NodeEventType::Custom(_) => {
                    // Take ownership of a new impulse-response payload; the
                    // actual swap is deferred until the output fades out.
                    if event.downcast_into_owned(&mut self.next_impulse_response) {
                        self.declick.fade_to_0(&extra.declick_values);
                    }
                }
                _ => (),
            }
        }

        // Once the fade-out has fully settled, swap in the pending impulse
        // response and (unless paused) start fading back in. The output is
        // silent for this block while the swap takes effect.
        if self.next_impulse_response.is_some() && self.declick == Declicker::SettledAt0 {
            let next_impulse_response = self.next_impulse_response.take().unwrap();
            self.impulse_response.replace(next_impulse_response);
            if !self.params.pause {
                self.declick.fade_to_1(&extra.declick_values);
            }
            return ProcessStatus::ClearAllOutputs;
        }

        if self.impulse_response.is_some() {
            // Render the smoothed wet gain into scratch buffer 0 so it can
            // be applied per-sample to every convolved channel below.
            const WET_GAIN_BUFFER: usize = 0;
            let wet_gain_buffer = &mut extra.scratch_buffers.channels_mut::<1>()[WET_GAIN_BUFFER];

            self.wet_gain_smoothed.process_into_buffer(wet_gain_buffer);

            // Paused and fully faded out: nothing to render this block.
            if self.params.pause && self.declick == Declicker::SettledAt0 {
                return ProcessStatus::ClearAllOutputs;
            }

            // Convolve each input channel with its matching convolver (if
            // the IR has that many channels), scaling the wet output by the
            // smoothed gain.
            for (input_index, input) in buffers.inputs.iter().enumerate() {
                if let Some(conv) = self
                    .impulse_response
                    .get_mut()
                    .as_mut()
                    .unwrap()
                    .0
                    .get_mut(input_index)
                {
                    conv.process(input, buffers.outputs[input_index]).unwrap();

                    for (output_sample, gain) in buffers.outputs[input_index]
                        .iter_mut()
                        .zip(wet_gain_buffer.iter())
                    {
                        *output_sample *= gain;
                    }
                }
            }
        }

        if self.impulse_response.is_some() {
            // Blend the dry input into the wet output per the current mix.
            match CHANNELS {
                1 => {
                    self.mix.mix_dry_into_wet_mono(
                        buffers.inputs[0],
                        buffers.outputs[0],
                        info.frames,
                    );
                }
                2 => {
                    // Split so both output channels can be borrowed mutably.
                    let (left, right) = buffers.outputs.split_at_mut(1);
                    self.mix.mix_dry_into_wet_stereo(
                        buffers.inputs[0],
                        buffers.inputs[1],
                        left[0],
                        right[0],
                        info.frames,
                    );
                }
                // `info` rejects CHANNELS > 2 before a processor is built.
                _ => panic!("Only Mono and Stereo are supported"),
            }
        } else {
            // No impulse response loaded: pass the input through untouched.
            for (input, output) in buffers.inputs.iter().zip(buffers.outputs.iter_mut()) {
                output.copy_from_slice(input);
            }
        }

        // Apply any in-progress pause/swap fade to the final output.
        self.declick.process(
            buffers.outputs,
            0..info.frames,
            &extra.declick_values,
            1.0,
            DeclickFadeCurve::EqualPower3dB,
        );

        buffers.check_for_silence_on_outputs(f32::EPSILON)
    }
}
342
#[cfg(test)]
mod tests {
    use super::*;

    /// `info` succeeds for the supported mono and stereo channel counts.
    #[test]
    fn mono_stereo_ok() {
        let mono = ConvolutionNode::<1>::default();
        mono.info(&ConvolutionNodeConfig::default());

        let stereo = ConvolutionNode::<2>::default();
        stereo.info(&ConvolutionNodeConfig::default());
    }

    /// `info` panics for any channel count above stereo.
    #[test]
    #[should_panic]
    fn fail_above_stereo() {
        let node = ConvolutionNode::<3>::default();
        node.info(&ConvolutionNodeConfig::default());
    }
}