Skip to main content

firewheel_ircam_hrtf/
lib.rs

1//! A head-related transfer function (HRTF) node for
2//! [Firewheel](https://github.com/BillyDM/Firewheel),
3//! powered by [Fyrox](https://docs.rs/hrtf/latest/hrtf/)'s
4//! [IRCAM](http://recherche.ircam.fr/equipes/salles/listen/download.html)-based HRIR.
5//!
6//! HRTFs can provide far more convincing spatialization compared to
7//! simpler techniques. They simulate the way our bodies filter sounds
8//! based on where they're coming from, allowing you to distinguish up/down,
9//! front/back, and the more typical left/right.
10//!
11//! This simulation is moderately expensive. You'll generally want to avoid more
12//! than 32-64 HRTF emitters, especially on less powerful devices.
13
14#![cfg_attr(docsrs, feature(doc_cfg))]
15#![warn(missing_debug_implementations)]
16#![warn(missing_docs)]
17
18use firewheel::{
19    channel_config::{ChannelConfig, NonZeroChannelCount},
20    diff::{Diff, Patch},
21    dsp::{coeff_update::CoeffUpdateFactor, distance_attenuation::DistanceAttenuatorStereoDsp},
22    event::ProcEvents,
23    node::{
24        AudioNode, AudioNodeInfo, AudioNodeProcessor, ProcBuffers, ProcExtra, ProcInfo,
25        ProcessStatus,
26    },
27};
28use glam::Vec3;
29use hrtf::{HrirSphere, HrtfContext, HrtfProcessor};
30use std::io::Cursor;
31
32mod subjects;
33
34pub use firewheel::dsp::distance_attenuation::{DistanceAttenuation, DistanceModel};
35pub use subjects::{Subject, SubjectBytes};
36
/// Head-related transfer function (HRTF) node.
///
/// HRTFs can provide far more convincing spatialization
/// compared to simpler techniques. They simulate the way
/// our bodies filter sounds based on where they’re coming from,
/// allowing you to distinguish up/down, front/back,
/// and the more typical left/right.
///
/// This simulation is moderately expensive. You’ll generally
/// want to avoid more than 32-64 HRTF emitters, especially on
/// less powerful devices.
#[derive(Debug, Clone, Diff, Patch)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::component::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub struct HrtfNode {
    /// The positional offset from the listener to the emitter.
    pub offset: Vec3,

    /// The amount of muffling (lowpass) in the range `[20.0, 20_480.0]`,
    /// where `20_480.0` is no muffling and `20.0` is maximum muffling.
    ///
    /// This can be used to give the effect of a sound being played behind a wall
    /// or underwater.
    ///
    /// By default this is set to `20_480.0`.
    ///
    /// See <https://www.desmos.com/calculator/jxp8t9ero4> for an interactive graph of
    /// how these parameters affect the final lowpass cutoff frequency.
    pub muffle_cutoff_hz: f32,

    /// Distance attenuation parameters.
    pub distance_attenuation: DistanceAttenuation,

    /// The time in seconds of the internal smoothing filter.
    ///
    /// By default this is set to `0.015` (15ms).
    pub smooth_seconds: f32,

    /// If the resulting gain (in raw amplitude, not decibels) is less than or equal
    /// to this value, the gain will be clamped to `0` (silence).
    ///
    /// By default this is set to `0.0001` (-80 dB).
    pub min_gain: f32,

    /// An exponent representing the rate at which DSP coefficients are
    /// updated when parameters are being smoothed.
    ///
    /// Smaller values will produce less "stair-stepping" artifacts,
    /// but will also consume more CPU.
    ///
    /// The resulting number of frames (samples in a single channel of audio)
    /// that will elapse between each update is calculated as
    /// `2^coeff_update_factor`.
    ///
    /// By default this is set to `5`.
    pub coeff_update_factor: CoeffUpdateFactor,
}
94
95impl Default for HrtfNode {
96    fn default() -> Self {
97        Self {
98            offset: Vec3::ZERO,
99            muffle_cutoff_hz: 20480.0,
100            distance_attenuation: Default::default(),
101            smooth_seconds: 0.015,
102            min_gain: 0.0001,
103            coeff_update_factor: CoeffUpdateFactor(5),
104        }
105    }
106}
107
/// Configuration for [`HrtfNode`].
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "bevy", derive(bevy_ecs::component::Component))]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub struct HrtfConfig {
    /// The number of input channels.
    ///
    /// The inputs are downmixed to a mono signal
    /// before spatialization is applied.
    ///
    /// Defaults to [`NonZeroChannelCount::STEREO`].
    pub input_channels: NonZeroChannelCount,

    /// The head-related impulse-response sphere.
    ///
    /// The data for this sphere is captured from subjects. Short
    /// "impulses" are played from all angles and recorded at the
    /// ear canal. The resulting recordings capture how sounds are affected
    /// by the subject's torso, head, and ears.
    ///
    /// Defaults to `HrirSource::Embedded(Subject::Irc1040)`.
    pub hrir_sphere: HrirSource,

    /// The size of the FFT processing block, which can be
    /// tuned for performance.
    ///
    /// Defaults to [`FftSize::default`] (4 slices of 128 samples).
    pub fft_size: FftSize,
}
135
136impl Default for HrtfConfig {
137    fn default() -> Self {
138        Self {
139            input_channels: NonZeroChannelCount::STEREO,
140            hrir_sphere: Subject::Irc1040.into(),
141            fft_size: FftSize::default(),
142        }
143    }
144}
145
/// Describes the size of the FFT processing block.
///
/// Generally, you should try to match the FFT size (the product of
/// [`slice_count`][FftSize::slice_count] and [`slice_len`][FftSize::slice_len])
/// to the audio's processing buffer size if possible.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub struct FftSize {
    /// The number of slices the audio stream is split into for overlap-save.
    ///
    /// Defaults to 4.
    pub slice_count: usize,

    /// The size of each slice, in frames.
    ///
    /// Defaults to 128.
    pub slice_len: usize,
}
164
165impl Default for FftSize {
166    fn default() -> Self {
167        Self {
168            slice_count: 4,
169            slice_len: 128,
170        }
171    }
172}
173
/// Provides a source for the HRIR (head-related impulse response) sphere data.
///
/// A [`Subject`] converts into [`HrirSource::Embedded`] and a [`SubjectBytes`]
/// converts into [`HrirSource::InMemory`] via `From`/`Into`.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub enum HrirSource {
    /// Load data from the subjects embedded in the binary itself.
    Embedded(Subject),
    /// Load arbitrary data from an in-memory slice.
    InMemory(SubjectBytes),
}
183
184impl HrirSource {
185    fn get_sphere(&self, sample_rate: u32) -> Result<HrirSphere, hrtf::HrtfError> {
186        match &self {
187            HrirSource::Embedded(subject) => HrirSphere::new(Cursor::new(*subject), sample_rate),
188            HrirSource::InMemory(subject) => {
189                HrirSphere::new(Cursor::new(subject.clone()), sample_rate)
190            }
191        }
192    }
193}
194
195impl From<Subject> for HrirSource {
196    fn from(value: Subject) -> Self {
197        Self::Embedded(value)
198    }
199}
200
201impl From<SubjectBytes> for HrirSource {
202    fn from(value: SubjectBytes) -> Self {
203        Self::InMemory(value)
204    }
205}
206
impl AudioNode for HrtfNode {
    type Configuration = HrtfConfig;

    fn info(&self, config: &Self::Configuration) -> AudioNodeInfo {
        // Any configured number of inputs (downmixed to mono by the
        // processor), always two output channels (left/right).
        AudioNodeInfo::new()
            .debug_name("hrtf node")
            .channel_config(ChannelConfig::new(config.input_channels.get(), 2))
    }

    fn construct_processor(
        &self,
        config: &Self::Configuration,
        cx: firewheel::node::ConstructProcessorContext,
    ) -> impl firewheel::node::AudioNodeProcessor {
        let sample_rate = cx.stream_info.sample_rate.get();

        // Panics on malformed HRIR data — presumably acceptable for the
        // embedded subjects; user-supplied `HrirSource::InMemory` data is the
        // case that could realistically trip this.
        let sphere = config
            .hrir_sphere
            .get_sphere(sample_rate)
            .expect("HRIR data should be in a valid format");

        // Total frames per FFT processing block.
        let fft_buffer_len = config.fft_size.slice_count * config.fft_size.slice_len;

        let renderer = HrtfProcessor::new(
            sphere,
            config.fft_size.slice_count,
            config.fft_size.slice_len,
        );

        let buffer_size = cx.stream_info.max_block_frames.get() as usize;
        FyroxHrtfProcessor {
            renderer,
            attenuation: self.distance_attenuation,
            attenuation_processor: DistanceAttenuatorStereoDsp::new(
                firewheel::param::smoother::SmootherConfig {
                    smooth_seconds: self.smooth_seconds,
                    ..Default::default()
                },
                cx.stream_info.sample_rate,
                self.coeff_update_factor,
            ),
            muffle_cutoff_hz: self.muffle_cutoff_hz,
            offset: self.offset,
            min_gain: self.min_gain,
            // NOTE(review): `Vec::with_capacity` guarantees only a *minimum*
            // capacity, so `fft_input.capacity()` may exceed `fft_buffer_len`.
            fft_input: Vec::with_capacity(fft_buffer_len),
            // Must hold at least one full FFT block or one audio block,
            // whichever is larger.
            fft_output: Vec::with_capacity(buffer_size.max(fft_buffer_len)),
            prev_left_samples: Vec::with_capacity(fft_buffer_len),
            prev_right_samples: Vec::with_capacity(fft_buffer_len),
            // Retained so the renderer can be rebuilt if the stream's
            // sample rate changes (see `new_stream`).
            sphere_source: config.hrir_sphere.clone(),
            fft_size: config.fft_size.clone(),
        }
    }
}
260
/// Realtime (audio-thread) state for [`HrtfNode`], backed by Fyrox's
/// HRTF renderer.
struct FyroxHrtfProcessor {
    // The Fyrox HRTF convolution renderer.
    renderer: HrtfProcessor,
    // Direction from the listener to the emitter. Normalized once the first
    // `Offset` patch arrives; initially the raw (unnormalized) node offset —
    // NOTE(review): confirm that's intentional.
    offset: Vec3,
    // Distance attenuation parameters mirrored from the node.
    attenuation: DistanceAttenuation,
    // Applies distance attenuation and muffling to the stereo output.
    attenuation_processor: DistanceAttenuatorStereoDsp,
    // Lowpass cutoff used when recomputing attenuation values.
    muffle_cutoff_hz: f32,
    // Gains at or below this raw amplitude are clamped to silence.
    min_gain: f32,
    // Mono accumulation buffer; flushed to the renderer once a full
    // FFT block has been collected.
    fft_input: Vec<f32>,
    // Rendered (left, right) samples awaiting copy into the outputs.
    fft_output: Vec<(f32, f32)>,
    // Previous-block samples handed back to the renderer via `HrtfContext`
    // (overlap/crossfade state managed by the `hrtf` crate).
    prev_left_samples: Vec<f32>,
    prev_right_samples: Vec<f32>,
    // Retained so the renderer can be rebuilt on sample-rate changes.
    sphere_source: HrirSource,
    fft_size: FftSize,
}
275
276impl AudioNodeProcessor for FyroxHrtfProcessor {
277    fn process(
278        &mut self,
279        proc_info: &ProcInfo,
280        ProcBuffers { inputs, outputs }: ProcBuffers,
281        events: &mut ProcEvents,
282        _: &mut ProcExtra,
283    ) -> ProcessStatus {
284        let mut previous_vector = self.offset;
285
286        for patch in events.drain_patches::<HrtfNode>() {
287            match patch {
288                HrtfNodePatch::Offset(offset) => {
289                    let distance = offset.length().max(0.01);
290
291                    self.attenuation_processor.compute_values(
292                        distance,
293                        &self.attenuation,
294                        self.muffle_cutoff_hz,
295                        self.min_gain,
296                    );
297
298                    self.offset = offset.normalize_or(Vec3::Y);
299                }
300                HrtfNodePatch::MuffleCutoffHz(muffle) => {
301                    self.muffle_cutoff_hz = muffle;
302                }
303                HrtfNodePatch::DistanceAttenuation(a) => {
304                    self.attenuation.apply(a);
305                }
306                HrtfNodePatch::SmoothSeconds(s) => {
307                    self.attenuation_processor
308                        .set_smooth_seconds(s, proc_info.sample_rate);
309                }
310                HrtfNodePatch::MinGain(g) => {
311                    self.min_gain = g;
312                }
313                HrtfNodePatch::CoeffUpdateFactor(c) => {
314                    self.attenuation_processor.set_coeff_update_factor(c);
315                }
316            }
317        }
318
319        if proc_info.in_silence_mask.all_channels_silent(inputs.len()) {
320            self.attenuation_processor.reset();
321
322            return ProcessStatus::ClearAllOutputs;
323        }
324
325        for frame in 0..proc_info.frames {
326            let mut downmixed = 0.0;
327            for channel in inputs {
328                downmixed += channel[frame];
329            }
330            downmixed /= inputs.len() as f32;
331
332            self.fft_input.push(downmixed);
333
334            // Buffer full, process FFT
335            if self.fft_input.len() == self.fft_input.capacity() {
336                let fft_len = self.fft_input.len();
337
338                let output_start = self.fft_output.len();
339                self.fft_output
340                    .extend(std::iter::repeat_n((0.0, 0.0), fft_len));
341
342                // let (left, right) = outputs.split_at_mut(1);
343                let context = HrtfContext {
344                    source: &self.fft_input,
345                    output: &mut self.fft_output[output_start..],
346                    new_sample_vector: hrtf::Vec3::new(self.offset.x, self.offset.y, self.offset.z),
347                    prev_sample_vector: hrtf::Vec3::new(
348                        previous_vector.x,
349                        previous_vector.y,
350                        previous_vector.z,
351                    ),
352                    prev_left_samples: &mut self.prev_left_samples,
353                    prev_right_samples: &mut self.prev_right_samples,
354                    new_distance_gain: 1.0,
355                    prev_distance_gain: 1.0,
356                };
357
358                self.renderer.process_samples(context);
359
360                // in case we call this multiple times
361                previous_vector = self.offset;
362                self.fft_input.clear();
363            }
364        }
365
366        for (i, (left, right)) in self
367            .fft_output
368            .drain(..proc_info.frames.min(self.fft_output.len()))
369            .enumerate()
370        {
371            outputs[0][i] = left;
372            outputs[1][i] = right;
373        }
374
375        let (left, rest) = outputs.split_first_mut().unwrap();
376        let clear_outputs = self.attenuation_processor.process(
377            proc_info.frames,
378            left,
379            rest[0],
380            proc_info.sample_rate_recip,
381        );
382
383        if clear_outputs {
384            self.attenuation_processor.reset();
385            ProcessStatus::ClearAllOutputs
386        } else {
387            ProcessStatus::OutputsModified
388        }
389    }
390
391    fn new_stream(
392        &mut self,
393        stream_info: &firewheel::StreamInfo,
394        _store: &mut firewheel::node::ProcStreamCtx,
395    ) {
396        if stream_info.prev_sample_rate != stream_info.sample_rate {
397            let sample_rate = stream_info.sample_rate.get();
398
399            let sphere = self
400                .sphere_source
401                .get_sphere(sample_rate)
402                .expect("HRIR data should be in a valid format");
403
404            let renderer =
405                HrtfProcessor::new(sphere, self.fft_size.slice_count, self.fft_size.slice_len);
406
407            self.renderer = renderer;
408        }
409    }
410}