1use core::f32;
2
3use fft_convolver::FFTConvolver;
4use firewheel_core::{
5 channel_config::{ChannelConfig, ChannelCount},
6 collector::OwnedGc,
7 diff::{Diff, Patch},
8 dsp::{
9 declick::{DeclickFadeCurve, Declicker},
10 fade::FadeCurve,
11 filter::smoothing_filter::DEFAULT_SMOOTH_SECONDS,
12 mix::{Mix, MixDSP},
13 volume::Volume,
14 },
15 event::NodeEventType,
16 node::{
17 AudioNode, AudioNodeInfo, AudioNodeProcessor, ConstructProcessorContext, ProcessStatus,
18 },
19 param::smoother::{SmoothedParam, SmootherConfig},
20 sample_resource::SampleResourceF32,
21};
22
/// A mono (single-channel) convolution node.
pub type ConvolutionMonoNode = ConvolutionNode<1>;
/// A stereo (two-channel) convolution node.
pub type ConvolutionStereoNode = ConvolutionNode<2>;
25
/// A node that convolves the input signal with an [`ImpulseResponse`]
/// (e.g. for reverb or cabinet simulation).
///
/// `CHANNELS` must be `1` (mono) or `2` (stereo); larger values cause
/// [`AudioNode::info`] to panic.
#[derive(Patch, Diff, Debug, Clone, Copy, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNode<const CHANNELS: usize = 2> {
    /// When `true`, the output is declicked (faded) to silence; setting it
    /// back to `false` fades the output back in.
    pub pause: bool,

    /// The dry/wet mix between the unprocessed input and the convolved
    /// signal.
    pub mix: Mix,

    /// The curve used when fading between the dry and wet signals.
    pub fade_curve: FadeCurve,

    /// The volume applied to the wet (convolved) signal before it is mixed
    /// with the dry signal.
    pub wet_gain: Volume,

    /// The time in seconds over which parameter changes (mix and wet gain)
    /// are smoothed.
    pub smooth_seconds: f32,
}
68
/// Configuration for a mono [`ConvolutionMonoNode`].
pub type ConvolutionMonoNodeConfig = ConvolutionNodeConfig<1>;
/// Configuration for a stereo [`ConvolutionStereoNode`].
pub type ConvolutionStereoNodeConfig = ConvolutionNodeConfig<2>;
71
/// Configuration for a [`ConvolutionNode`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::prelude::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ConvolutionNodeConfig<const CHANNELS: usize = 2> {
    /// The maximum number of channels an impulse response may have.
    /// Defaults to `CHANNELS`.
    // NOTE(review): this field is not read anywhere in this file — confirm
    // where the limit is enforced.
    pub max_impulse_channel_count: ChannelCount,

    /// The FFT partition size used by the convolver. Defaults to
    /// [`DEFAULT_PARTITION_SIZE`].
    // NOTE(review): also not read in this file; `ImpulseResponse` carries
    // its own partition size at construction — confirm intended use.
    pub partition_size: usize,
}
87
88pub const DEFAULT_PARTITION_SIZE: usize = 1024;
92
93pub struct ImpulseResponse(Vec<FFTConvolver<f32>>);
97
98impl ImpulseResponse {
99 pub fn new_with_partition_size(sample: impl SampleResourceF32, partition_size: usize) -> Self {
103 let num_channels = sample.num_channels().get();
104 Self(
105 (0..num_channels)
106 .map(|channel_index| {
107 let mut conv = FFTConvolver::default();
108 conv.init(partition_size, sample.channel(channel_index).unwrap())
112 .unwrap();
113 conv
114 })
115 .collect(),
116 )
117 }
118
119 pub fn new(sample: impl SampleResourceF32) -> Self {
121 Self::new_with_partition_size(sample, DEFAULT_PARTITION_SIZE)
122 }
123}
124
125impl<const CHANNELS: usize> Default for ConvolutionNodeConfig<CHANNELS> {
126 fn default() -> Self {
127 Self {
128 max_impulse_channel_count: ChannelCount::new(CHANNELS as u32).unwrap(),
130 partition_size: DEFAULT_PARTITION_SIZE,
131 }
132 }
133}
134
135impl<const CHANNELS: usize> Default for ConvolutionNode<CHANNELS> {
136 fn default() -> Self {
137 Self {
138 mix: Mix::CENTER,
139 fade_curve: FadeCurve::default(),
140 wet_gain: Volume::Decibels(-20.0),
141 pause: false,
142 smooth_seconds: DEFAULT_SMOOTH_SECONDS,
143 }
144 }
145}
146
147impl<const CHANNELS: usize> AudioNode for ConvolutionNode<CHANNELS> {
148 type Configuration = ConvolutionNodeConfig<CHANNELS>;
149
150 fn info(&self, _configuration: &Self::Configuration) -> AudioNodeInfo {
151 if CHANNELS > 2 {
152 panic!(
153 "ConvolutionNode::CHANNELS cannot be greater than 2, got {}",
154 CHANNELS
155 );
156 }
157 AudioNodeInfo::new()
158 .debug_name("convolution")
159 .channel_config(ChannelConfig::new(CHANNELS, CHANNELS))
160 }
161
162 fn construct_processor(
163 &self,
164 _configuration: &Self::Configuration,
165 cx: ConstructProcessorContext,
166 ) -> impl AudioNodeProcessor {
167 let sample_rate = cx.stream_info.sample_rate;
168 let smooth_config = SmootherConfig {
169 smooth_seconds: self.smooth_seconds,
170 ..Default::default()
171 };
172 ConvolutionProcessor::<CHANNELS> {
173 params: self.clone(),
174 mix: MixDSP::new(self.mix, self.fade_curve, smooth_config, sample_rate),
175 wet_gain_smoothed: SmoothedParam::new(self.wet_gain.amp(), smooth_config, sample_rate),
176 declick: Declicker::default(),
177 impulse_response: OwnedGc::new(None),
178 next_impulse_response: OwnedGc::new(None),
179 }
180 }
181}
182
/// Custom events for [`ConvolutionNode`].
pub enum ConvolutionNodeEvent {
    /// Replaces the current impulse response. The processor fades the
    /// output to silence before performing the swap to avoid clicks.
    // NOTE(review): a `None` payload fades the output out, but the swap
    // branch in the processor only runs when a new response is present, so
    // the output may never fade back in — confirm that clearing with `None`
    // behaves as intended.
    SetImpulseResponse(Option<ImpulseResponse>),
}
186
/// The realtime processor counterpart to [`ConvolutionNode`].
struct ConvolutionProcessor<const CHANNELS: usize> {
    /// The latest parameter values, kept in sync via applied patches.
    params: ConvolutionNode<CHANNELS>,
    /// Smoothed dry/wet mixer.
    mix: MixDSP,
    /// Smoothed wet-signal gain (raw amplitude).
    wet_gain_smoothed: SmoothedParam,
    /// Fades the output to/from silence on pause and impulse-response swaps.
    declick: Declicker,
    /// The currently active impulse response, if any.
    impulse_response: OwnedGc<Option<ImpulseResponse>>,
    /// A pending impulse response, swapped in once the declicker has fully
    /// faded the output to silence.
    next_impulse_response: OwnedGc<Option<ImpulseResponse>>,
}
198
impl<const CHANNELS: usize> AudioNodeProcessor for ConvolutionProcessor<CHANNELS> {
    /// Processes one block of audio: applies parameter patches, swaps in a
    /// pending impulse response, convolves the input, applies the wet gain,
    /// mixes dry into wet, and declicks the result.
    fn process(
        &mut self,
        info: &firewheel_core::node::ProcInfo,
        buffers: firewheel_core::node::ProcBuffers,
        events: &mut firewheel_core::event::ProcEvents,
        extra: &mut firewheel_core::node::ProcExtra,
    ) -> ProcessStatus {
        // Drain all pending events before touching the audio buffers.
        for mut event in events.drain() {
            match event {
                NodeEventType::Param { data, path } => {
                    if let Ok(patch) = ConvolutionNode::<CHANNELS>::patch(&data, &path) {
                        match patch {
                            ConvolutionNodePatch::Mix(mix) => {
                                self.mix.set_mix(mix, self.params.fade_curve);
                            }
                            ConvolutionNodePatch::FadeCurve(curve) => {
                                self.mix.set_mix(self.params.mix, curve);
                            }
                            ConvolutionNodePatch::WetGain(gain) => {
                                // Smooth toward the new gain as raw amplitude.
                                self.wet_gain_smoothed.set_value(gain.amp());
                            }
                            ConvolutionNodePatch::Pause(pause) => {
                                // Fade out when pausing, back in when resuming.
                                self.declick.fade_to_enabled(!pause, &extra.declick_values);
                            }
                            ConvolutionNodePatch::SmoothSeconds(smooth_seconds) => {
                                // Rebuild the mixer with the new smoothing time,
                                // preserving the current mix and fade curve.
                                self.mix = MixDSP::new(
                                    self.params.mix,
                                    self.params.fade_curve,
                                    SmootherConfig {
                                        smooth_seconds,
                                        ..Default::default()
                                    },
                                    info.sample_rate,
                                );
                                self.wet_gain_smoothed
                                    .set_smooth_seconds(smooth_seconds, info.sample_rate);
                            }
                        }
                        // Keep the mirrored parameter struct up to date.
                        self.params.apply(patch);
                    }
                }
                NodeEventType::Custom(_) => {
                    // A new impulse response arrived: stash it and fade the
                    // output to silence so the swap below is click-free.
                    if event.downcast_into_owned(&mut self.next_impulse_response) {
                        self.declick.fade_to_0(&extra.declick_values);
                    }
                }
                _ => (),
            }
        }

        // Swap in the pending impulse response only once the output has
        // fully faded to silence.
        if self.next_impulse_response.is_some() && self.declick == Declicker::SettledAt0 {
            let next_impulse_response = self.next_impulse_response.take().unwrap();
            self.impulse_response.replace(next_impulse_response);
            if !self.params.pause {
                self.declick.fade_to_1(&extra.declick_values);
            }
            // The output is silent for this block anyway; skip processing.
            return ProcessStatus::ClearAllOutputs;
        }

        if self.impulse_response.is_some() {
            // Render the smoothed wet gain for this block into scratch
            // buffer 0.
            const WET_GAIN_BUFFER: usize = 0;
            let wet_gain_buffer = &mut extra.scratch_buffers.channels_mut::<1>()[WET_GAIN_BUFFER];

            self.wet_gain_smoothed.process_into_buffer(wet_gain_buffer);

            // Fully paused and faded out: nothing to render.
            if self.params.pause && self.declick == Declicker::SettledAt0 {
                return ProcessStatus::ClearAllOutputs;
            }

            // Convolve each input channel with its matching convolver, then
            // scale the wet result by the per-sample smoothed gain.
            // NOTE(review): if the impulse response has fewer channels than
            // the node, the remaining output channels are left untouched
            // here — confirm they are guaranteed to be cleared upstream.
            for (input_index, input) in buffers.inputs.iter().enumerate() {
                if let Some(conv) = self
                    .impulse_response
                    .get_mut()
                    .as_mut()
                    .unwrap()
                    .0
                    .get_mut(input_index)
                {
                    conv.process(input, buffers.outputs[input_index]).unwrap();

                    for (output_sample, gain) in buffers.outputs[input_index]
                        .iter_mut()
                        .zip(wet_gain_buffer.iter())
                    {
                        *output_sample *= gain;
                    }
                }
            }
        }

        if self.impulse_response.is_some() {
            // Blend the dry input into the wet output in place.
            match CHANNELS {
                1 => {
                    self.mix.mix_dry_into_wet_mono(
                        buffers.inputs[0],
                        buffers.outputs[0],
                        info.frames,
                    );
                }
                2 => {
                    // Split so both output channels can be borrowed mutably
                    // at once.
                    let (left, right) = buffers.outputs.split_at_mut(1);
                    self.mix.mix_dry_into_wet_stereo(
                        buffers.inputs[0],
                        buffers.inputs[1],
                        left[0],
                        right[0],
                        info.frames,
                    );
                }
                _ => panic!("Only Mono and Stereo are supported"),
            }
        } else {
            // No impulse response loaded: pass the input through unchanged.
            for (input, output) in buffers.inputs.iter().zip(buffers.outputs.iter_mut()) {
                output.copy_from_slice(input);
            }
        }

        // Apply any in-progress fade (pause/unpause or impulse swap).
        self.declick.process(
            buffers.outputs,
            0..info.frames,
            &extra.declick_values,
            1.0,
            DeclickFadeCurve::EqualPower3dB,
        );

        buffers.check_for_silence_on_outputs(f32::EPSILON)
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// Mono and stereo channel counts construct node info without panicking.
    #[test]
    fn mono_stereo_ok() {
        ConvolutionMonoNode::default().info(&ConvolutionMonoNodeConfig::default());
        ConvolutionStereoNode::default().info(&ConvolutionStereoNodeConfig::default());
    }

    /// Channel counts above two must be rejected by `info()`.
    #[test]
    #[should_panic]
    fn fail_above_stereo() {
        ConvolutionNode::<3>::default().info(&ConvolutionNodeConfig::default());
    }
}