1#![cfg_attr(docsrs, feature(doc_cfg))]
15#![warn(missing_debug_implementations)]
16#![warn(missing_docs)]
17
18use firewheel::{
19 channel_config::{ChannelConfig, NonZeroChannelCount},
20 diff::{Diff, Patch},
21 dsp::{coeff_update::CoeffUpdateFactor, distance_attenuation::DistanceAttenuatorStereoDsp},
22 event::ProcEvents,
23 node::{
24 AudioNode, AudioNodeInfo, AudioNodeProcessor, ProcBuffers, ProcExtra, ProcInfo,
25 ProcessStatus,
26 },
27};
28use glam::Vec3;
29use hrtf::{HrirSphere, HrtfContext, HrtfProcessor};
30use std::io::Cursor;
31
32mod subjects;
33
34pub use firewheel::dsp::distance_attenuation::{DistanceAttenuation, DistanceModel};
35pub use subjects::{Subject, SubjectBytes};
36
37#[derive(Debug, Clone, Diff, Patch)]
49#[cfg_attr(feature = "bevy", derive(bevy_ecs::component::Component))]
50#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
51pub struct HrtfNode {
52 pub offset: Vec3,
54
55 pub muffle_cutoff_hz: f32,
66
67 pub distance_attenuation: DistanceAttenuation,
69
70 pub smooth_seconds: f32,
74
75 pub min_gain: f32,
80
81 pub coeff_update_factor: CoeffUpdateFactor,
93}
94
95impl Default for HrtfNode {
96 fn default() -> Self {
97 Self {
98 offset: Vec3::ZERO,
99 muffle_cutoff_hz: 20480.0,
100 distance_attenuation: Default::default(),
101 smooth_seconds: 0.015,
102 min_gain: 0.0001,
103 coeff_update_factor: CoeffUpdateFactor(5),
104 }
105 }
106}
107
108#[derive(Debug, Clone, PartialEq)]
110#[cfg_attr(feature = "bevy", derive(bevy_ecs::component::Component))]
111#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
112pub struct HrtfConfig {
113 pub input_channels: NonZeroChannelCount,
120
121 pub hrir_sphere: HrirSource,
130
131 pub fft_size: FftSize,
134}
135
136impl Default for HrtfConfig {
137 fn default() -> Self {
138 Self {
139 input_channels: NonZeroChannelCount::STEREO,
140 hrir_sphere: Subject::Irc1040.into(),
141 fft_size: FftSize::default(),
142 }
143 }
144}
145
/// Block sizing for the HRTF renderer.
///
/// The renderer consumes `slice_count * slice_len` mono samples per call
/// (`4 * 128 = 512` by default). Input is buffered until that many samples
/// are available, so larger values mean more buffering before output appears.
/// Both fields are expected to be non-zero.
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
pub struct FftSize {
    /// Number of slices per renderer call. It is passed to
    /// `hrtf::HrtfProcessor::new` as its first sizing argument (presumably the
    /// number of HRIR interpolation steps; confirm against the `hrtf` docs).
    pub slice_count: usize,

    /// Length of each slice in samples. It is passed to
    /// `hrtf::HrtfProcessor::new` as its second sizing argument.
    pub slice_len: usize,
}

impl Default for FftSize {
    fn default() -> Self {
        Self {
            slice_count: 4,
            slice_len: 128,
        }
    }
}
173
174#[derive(Debug, Clone, PartialEq)]
176#[cfg_attr(feature = "bevy_reflect", derive(bevy_reflect::Reflect))]
177pub enum HrirSource {
178 Embedded(Subject),
180 InMemory(SubjectBytes),
182}
183
184impl HrirSource {
185 fn get_sphere(&self, sample_rate: u32) -> Result<HrirSphere, hrtf::HrtfError> {
186 match &self {
187 HrirSource::Embedded(subject) => HrirSphere::new(Cursor::new(*subject), sample_rate),
188 HrirSource::InMemory(subject) => {
189 HrirSphere::new(Cursor::new(subject.clone()), sample_rate)
190 }
191 }
192 }
193}
194
195impl From<Subject> for HrirSource {
196 fn from(value: Subject) -> Self {
197 Self::Embedded(value)
198 }
199}
200
201impl From<SubjectBytes> for HrirSource {
202 fn from(value: SubjectBytes) -> Self {
203 Self::InMemory(value)
204 }
205}
206
207impl AudioNode for HrtfNode {
208 type Configuration = HrtfConfig;
209
210 fn info(&self, config: &Self::Configuration) -> AudioNodeInfo {
211 AudioNodeInfo::new()
212 .debug_name("hrtf node")
213 .channel_config(ChannelConfig::new(config.input_channels.get(), 2))
214 }
215
216 fn construct_processor(
217 &self,
218 config: &Self::Configuration,
219 cx: firewheel::node::ConstructProcessorContext,
220 ) -> impl firewheel::node::AudioNodeProcessor {
221 let sample_rate = cx.stream_info.sample_rate.get();
222
223 let sphere = config
224 .hrir_sphere
225 .get_sphere(sample_rate)
226 .expect("HRIR data should be in a valid format");
227
228 let fft_buffer_len = config.fft_size.slice_count * config.fft_size.slice_len;
229
230 let renderer = HrtfProcessor::new(
231 sphere,
232 config.fft_size.slice_count,
233 config.fft_size.slice_len,
234 );
235
236 let buffer_size = cx.stream_info.max_block_frames.get() as usize;
237 FyroxHrtfProcessor {
238 renderer,
239 attenuation: self.distance_attenuation,
240 attenuation_processor: DistanceAttenuatorStereoDsp::new(
241 firewheel::param::smoother::SmootherConfig {
242 smooth_seconds: self.smooth_seconds,
243 ..Default::default()
244 },
245 cx.stream_info.sample_rate,
246 self.coeff_update_factor,
247 ),
248 muffle_cutoff_hz: self.muffle_cutoff_hz,
249 offset: self.offset,
250 min_gain: self.min_gain,
251 fft_input: Vec::with_capacity(fft_buffer_len),
252 fft_output: Vec::with_capacity(buffer_size.max(fft_buffer_len)),
253 prev_left_samples: Vec::with_capacity(fft_buffer_len),
254 prev_right_samples: Vec::with_capacity(fft_buffer_len),
255 sphere_source: config.hrir_sphere.clone(),
256 fft_size: config.fft_size.clone(),
257 }
258 }
259}
260
261struct FyroxHrtfProcessor {
262 renderer: HrtfProcessor,
263 offset: Vec3,
264 attenuation: DistanceAttenuation,
265 attenuation_processor: DistanceAttenuatorStereoDsp,
266 muffle_cutoff_hz: f32,
267 min_gain: f32,
268 fft_input: Vec<f32>,
269 fft_output: Vec<(f32, f32)>,
270 prev_left_samples: Vec<f32>,
271 prev_right_samples: Vec<f32>,
272 sphere_source: HrirSource,
273 fft_size: FftSize,
274}
275
276impl AudioNodeProcessor for FyroxHrtfProcessor {
277 fn process(
278 &mut self,
279 proc_info: &ProcInfo,
280 ProcBuffers { inputs, outputs }: ProcBuffers,
281 events: &mut ProcEvents,
282 _: &mut ProcExtra,
283 ) -> ProcessStatus {
284 let mut previous_vector = self.offset;
285
286 for patch in events.drain_patches::<HrtfNode>() {
287 match patch {
288 HrtfNodePatch::Offset(offset) => {
289 let distance = offset.length().max(0.01);
290
291 self.attenuation_processor.compute_values(
292 distance,
293 &self.attenuation,
294 self.muffle_cutoff_hz,
295 self.min_gain,
296 );
297
298 self.offset = offset.normalize_or(Vec3::Y);
299 }
300 HrtfNodePatch::MuffleCutoffHz(muffle) => {
301 self.muffle_cutoff_hz = muffle;
302 }
303 HrtfNodePatch::DistanceAttenuation(a) => {
304 self.attenuation.apply(a);
305 }
306 HrtfNodePatch::SmoothSeconds(s) => {
307 self.attenuation_processor
308 .set_smooth_seconds(s, proc_info.sample_rate);
309 }
310 HrtfNodePatch::MinGain(g) => {
311 self.min_gain = g;
312 }
313 HrtfNodePatch::CoeffUpdateFactor(c) => {
314 self.attenuation_processor.set_coeff_update_factor(c);
315 }
316 }
317 }
318
319 if proc_info.in_silence_mask.all_channels_silent(inputs.len()) {
320 self.attenuation_processor.reset();
321
322 return ProcessStatus::ClearAllOutputs;
323 }
324
325 for frame in 0..proc_info.frames {
326 let mut downmixed = 0.0;
327 for channel in inputs {
328 downmixed += channel[frame];
329 }
330 downmixed /= inputs.len() as f32;
331
332 self.fft_input.push(downmixed);
333
334 if self.fft_input.len() == self.fft_input.capacity() {
336 let fft_len = self.fft_input.len();
337
338 let output_start = self.fft_output.len();
339 self.fft_output
340 .extend(std::iter::repeat_n((0.0, 0.0), fft_len));
341
342 let context = HrtfContext {
344 source: &self.fft_input,
345 output: &mut self.fft_output[output_start..],
346 new_sample_vector: hrtf::Vec3::new(self.offset.x, self.offset.y, self.offset.z),
347 prev_sample_vector: hrtf::Vec3::new(
348 previous_vector.x,
349 previous_vector.y,
350 previous_vector.z,
351 ),
352 prev_left_samples: &mut self.prev_left_samples,
353 prev_right_samples: &mut self.prev_right_samples,
354 new_distance_gain: 1.0,
355 prev_distance_gain: 1.0,
356 };
357
358 self.renderer.process_samples(context);
359
360 previous_vector = self.offset;
362 self.fft_input.clear();
363 }
364 }
365
366 for (i, (left, right)) in self
367 .fft_output
368 .drain(..proc_info.frames.min(self.fft_output.len()))
369 .enumerate()
370 {
371 outputs[0][i] = left;
372 outputs[1][i] = right;
373 }
374
375 let (left, rest) = outputs.split_first_mut().unwrap();
376 let clear_outputs = self.attenuation_processor.process(
377 proc_info.frames,
378 left,
379 rest[0],
380 proc_info.sample_rate_recip,
381 );
382
383 if clear_outputs {
384 self.attenuation_processor.reset();
385 ProcessStatus::ClearAllOutputs
386 } else {
387 ProcessStatus::OutputsModified
388 }
389 }
390
391 fn new_stream(
392 &mut self,
393 stream_info: &firewheel::StreamInfo,
394 _store: &mut firewheel::node::ProcStreamCtx,
395 ) {
396 if stream_info.prev_sample_rate != stream_info.sample_rate {
397 let sample_rate = stream_info.sample_rate.get();
398
399 let sphere = self
400 .sphere_source
401 .get_sphere(sample_rate)
402 .expect("HRIR data should be in a valid format");
403
404 let renderer =
405 HrtfProcessor::new(sphere, self.fft_size.slice_count, self.fft_size.slice_len);
406
407 self.renderer = renderer;
408 }
409 }
410}