1pub mod ai_personalization;
4pub mod database;
5
6use crate::types::Position3D;
7pub use database::{
8 DatabaseConfig, DatabaseStatistics, HrtfDatabaseManager, HrtfMeasurement as DbHrtfMeasurement,
9 HrtfPosition, InterpolationMethod as DbInterpolationMethod, PersonalizedHrtf, StorageFormat,
10};
11use scirs2_core::ndarray::Array1;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::path::PathBuf;
15use std::sync::Arc;
16
17pub struct HrtfProcessor {
19 database: Arc<HrtfDatabase>,
21 #[allow(dead_code)]
23 buffer_size: usize,
24 #[allow(dead_code)]
26 overlap_left: Array1<f32>,
27 #[allow(dead_code)]
29 overlap_right: Array1<f32>,
30 config: HrtfConfig,
32}
33
/// In-memory HRIR database keyed by integer `(azimuth, elevation)` angles in
/// degrees (the processor converts positions with `to_degrees` before lookup).
#[derive(Clone)]
pub struct HrtfDatabase {
    /// Descriptive metadata, including the angle grids used for lookups.
    metadata: HrtfMetadata,
    /// Left-ear impulse responses keyed by `(azimuth, elevation)`.
    left_responses: HashMap<(i32, i32), Array1<f32>>,
    /// Right-ear impulse responses keyed by `(azimuth, elevation)`.
    right_responses: HashMap<(i32, i32), Array1<f32>>,
    /// Optional per-direction `(left, right)` frequency responses.
    // The visible loaders always set this to `None`, hence the dead_code allowance.
    #[allow(dead_code)]
    #[allow(clippy::type_complexity)]
    frequency_responses: Option<HashMap<(i32, i32), (Array1<f32>, Array1<f32>)>>,
    /// Optional `(left, right)` responses with an extra `u32` key component
    /// (presumably a distance bin — TODO confirm).
    // Also always `None` in the visible loaders.
    #[allow(dead_code)]
    #[allow(clippy::type_complexity)]
    distance_responses: Option<HashMap<(i32, i32, u32), (Array1<f32>, Array1<f32>)>>,
}
52
/// Configuration for [`HrtfProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HrtfConfig {
    /// Output sample rate in Hz; used to turn the near-field ITD (seconds) into samples.
    pub sample_rate: u32,
    /// HRIR length in samples produced by the interpolating lookups; the
    /// overlap-add buffers are initially twice this length.
    pub hrir_length: usize,
    /// Crossfade duration.
    /// NOTE(review): not read by the visible code; unit presumably seconds — confirm.
    pub crossfade_time: f32,
    /// When `true`, `process_position` applies distance-dependent gain and,
    /// inside `near_field_distance`, a proximity ITD delay.
    pub enable_distance_modeling: bool,
    /// How HRIRs are selected or blended for a given direction.
    pub interpolation_method: InterpolationMethod,
    /// Listener head circumference for the near-field ITD model; 57.0 is used
    /// when `None`. Presumably centimetres — TODO confirm.
    pub head_circumference: Option<f32>,
    /// Distance below which the near-field gain boost applies; air absorption
    /// is applied only beyond it. Same units as `Position3D`.
    pub near_field_distance: f32,
    /// Distance beyond which an extra far-field attenuation factor applies.
    pub far_field_distance: f32,
    /// When `true`, applies the simplified air-absorption model to distant sources.
    pub enable_air_absorption: bool,
    /// Air temperature in °C (converted to kelvin by adding 273.15).
    pub temperature: f32,
    /// Relative humidity for the air-absorption model.
    /// NOTE(review): scale (fraction vs. percent) is not established here.
    pub humidity: f32,
    /// SIMD toggle. NOTE(review): not read by the visible code.
    pub enable_simd: bool,
}
81
/// Descriptive information about an [`HrtfDatabase`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HrtfMetadata {
    /// Human-readable database name.
    pub name: String,
    /// Sample rate of the stored HRIRs (as declared by the source file).
    pub sample_rate: u32,
    /// Declared HRIR length in samples.
    pub hrir_length: usize,
    /// Azimuth grid in degrees; nearest/bilinear lookups snap to these values.
    pub azimuth_angles: Vec<i32>,
    /// Elevation grid in degrees; nearest/bilinear lookups snap to these values.
    pub elevation_angles: Vec<i32>,
    /// Measurement distances, if known.
    pub distances: Option<Vec<f32>>,
    /// Optional anthropometric data for the measured subject.
    pub subject_info: Option<SubjectInfo>,
}
100
/// Anthropometric measurements of the HRTF subject.
/// NOTE(review): the loaders fill placeholder values (56.0, 15.0, 20.0, 10.0,
/// 40.0) that look like centimetres — confirm the unit.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubjectInfo {
    /// Head circumference.
    pub head_circumference: f32,
    /// Head width.
    pub head_width: f32,
    /// Head height.
    pub head_height: f32,
    /// Ear height.
    pub ear_height: f32,
    /// Shoulder width.
    pub shoulder_width: f32,
}
115
/// Strategy used by `HrtfProcessor` to derive an HRIR pair for a direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpolationMethod {
    /// Use the closest grid point (azimuth and elevation snapped separately).
    Nearest,
    /// Blend the four grid points bracketing the azimuth and elevation.
    Bilinear,
    /// Inverse-distance blend of the four measurements closest on the sphere.
    Spherical,
    /// Inverse-distance blend of every measurement within 30 degrees.
    Weighted,
}
128
/// Source direction and range relative to the listener.
#[derive(Debug, Clone, Copy)]
pub struct SphericalCoordinates {
    /// Azimuth in degrees (`atan2(z, x)` of the Cartesian position).
    pub azimuth: f32,
    /// Elevation in degrees (`asin(y / distance)`).
    pub elevation: f32,
    /// Euclidean distance, in the same units as `Position3D`.
    pub distance: f32,
}
139
/// One measured HRIR pair together with the direction it was measured at.
#[derive(Debug, Clone)]
pub struct HrtfMeasurement {
    /// Measurement direction and distance.
    pub coordinates: SphericalCoordinates,
    /// Left-ear impulse response.
    pub left_hrir: Array1<f32>,
    /// Right-ear impulse response.
    pub right_hrir: Array1<f32>,
}
150
151impl HrtfProcessor {
152 pub async fn new(database_path: Option<PathBuf>) -> crate::Result<Self> {
154 let database = if let Some(path) = database_path {
155 HrtfDatabase::load_from_file(&path).await?
156 } else {
157 HrtfDatabase::load_default().await?
158 };
159
160 let config = HrtfConfig::default();
161 let buffer_size = config.hrir_length * 2; Ok(Self {
164 database: Arc::new(database),
165 buffer_size,
166 overlap_left: Array1::zeros(buffer_size),
167 overlap_right: Array1::zeros(buffer_size),
168 config,
169 })
170 }
171
172 pub async fn new_default() -> crate::Result<Self> {
174 Self::new(None).await
175 }
176
177 pub async fn with_config(
179 database_path: Option<PathBuf>,
180 config: HrtfConfig,
181 ) -> crate::Result<Self> {
182 let database = if let Some(path) = database_path {
183 HrtfDatabase::load_from_file(&path).await?
184 } else {
185 HrtfDatabase::load_default().await?
186 };
187
188 let buffer_size = config.hrir_length * 2;
189
190 Ok(Self {
191 database: Arc::new(database),
192 buffer_size,
193 overlap_left: Array1::zeros(buffer_size),
194 overlap_right: Array1::zeros(buffer_size),
195 config,
196 })
197 }
198
199 pub async fn process_position(
201 &self,
202 input: &Array1<f32>,
203 left_output: &mut Array1<f32>,
204 right_output: &mut Array1<f32>,
205 position: &Position3D,
206 ) -> crate::Result<()> {
207 let spherical = self.cartesian_to_spherical(position);
209
210 let (mut left_hrir, mut right_hrir) = self.get_hrtf(&spherical)?;
212
213 if self.config.enable_distance_modeling {
215 self.apply_distance_modeling(&mut left_hrir, &mut right_hrir, &spherical)?;
216 }
217
218 if self.config.enable_air_absorption && spherical.distance > self.config.near_field_distance
220 {
221 self.apply_air_absorption(&mut left_hrir, &mut right_hrir, spherical.distance)?;
222 }
223
224 self.convolve_hrtf(input, &left_hrir, &right_hrir, left_output, right_output)?;
226
227 Ok(())
228 }
229
230 pub async fn process_position_smooth(
232 &self,
233 input: &Array1<f32>,
234 left_output: &mut Array1<f32>,
235 right_output: &mut Array1<f32>,
236 start_position: &Position3D,
237 end_position: &Position3D,
238 progress: f32,
239 ) -> crate::Result<()> {
240 let current_position = Position3D::new(
242 start_position.x * (1.0 - progress) + end_position.x * progress,
243 start_position.y * (1.0 - progress) + end_position.y * progress,
244 start_position.z * (1.0 - progress) + end_position.z * progress,
245 );
246
247 self.process_position(input, left_output, right_output, ¤t_position)
248 .await
249 }
250
251 pub async fn process_realtime_chunk(
253 &mut self,
254 input: &Array1<f32>,
255 left_output: &mut Array1<f32>,
256 right_output: &mut Array1<f32>,
257 position: &Position3D,
258 ) -> crate::Result<()> {
259 let chunk_size = input.len();
260 let hrir_len = self.config.hrir_length;
261
262 if left_output.len() != chunk_size || right_output.len() != chunk_size {
264 return Err(crate::Error::processing(
265 "Output buffer size must match input chunk size",
266 ));
267 }
268
269 let spherical = self.cartesian_to_spherical(position);
271
272 let (left_hrir, right_hrir) = self.get_hrtf(&spherical)?;
274
275 let conv_len = chunk_size + hrir_len - 1;
277 let mut left_conv = Array1::zeros(conv_len);
278 let mut right_conv = Array1::zeros(conv_len);
279
280 self.convolve_hrtf(
282 input,
283 &left_hrir,
284 &right_hrir,
285 &mut left_conv,
286 &mut right_conv,
287 )?;
288
289 for i in 0..chunk_size {
291 left_output[i] = left_conv[i] + self.overlap_left[i];
292 right_output[i] = right_conv[i] + self.overlap_right[i];
293 }
294
295 self.overlap_left.fill(0.0);
297 self.overlap_right.fill(0.0);
298
299 let tail_start = chunk_size;
300 let tail_len = (conv_len - chunk_size).min(self.overlap_left.len());
301
302 for i in 0..tail_len {
303 self.overlap_left[i] = left_conv[tail_start + i];
304 self.overlap_right[i] = right_conv[tail_start + i];
305 }
306
307 Ok(())
308 }
309
310 pub fn reset_buffers(&mut self) {
312 self.overlap_left.fill(0.0);
313 self.overlap_right.fill(0.0);
314 }
315
316 pub async fn process_crossfade(
318 &self,
319 input: &Array1<f32>,
320 left_output: &mut Array1<f32>,
321 right_output: &mut Array1<f32>,
322 positions: &[(Position3D, f32)], ) -> crate::Result<()> {
324 left_output.fill(0.0);
325 right_output.fill(0.0);
326
327 let mut temp_left = Array1::zeros(input.len());
328 let mut temp_right = Array1::zeros(input.len());
329
330 for (position, weight) in positions {
331 self.process_position(input, &mut temp_left, &mut temp_right, position)
333 .await?;
334
335 for i in 0..left_output.len() {
337 left_output[i] += temp_left[i] * weight;
338 right_output[i] += temp_right[i] * weight;
339 }
340 }
341
342 Ok(())
343 }
344
345 fn get_hrtf(&self, coords: &SphericalCoordinates) -> crate::Result<(Array1<f32>, Array1<f32>)> {
347 match self.config.interpolation_method {
348 InterpolationMethod::Nearest => self.get_nearest_hrtf(coords),
349 InterpolationMethod::Bilinear => self.get_bilinear_hrtf(coords),
350 InterpolationMethod::Spherical => self.get_spherical_hrtf(coords),
351 InterpolationMethod::Weighted => self.get_weighted_hrtf(coords),
352 }
353 }
354
355 fn get_nearest_hrtf(
357 &self,
358 coords: &SphericalCoordinates,
359 ) -> crate::Result<(Array1<f32>, Array1<f32>)> {
360 let azimuth = coords.azimuth.round() as i32;
361 let elevation = coords.elevation.round() as i32;
362
363 let closest_azimuth =
365 self.find_closest_angle(azimuth, &self.database.metadata.azimuth_angles);
366 let closest_elevation =
367 self.find_closest_angle(elevation, &self.database.metadata.elevation_angles);
368
369 let key = (closest_azimuth, closest_elevation);
370
371 let left_hrir = self.database.left_responses.get(&key).ok_or_else(|| {
372 crate::Error::LegacyHrtf(format!(
373 "No HRTF found for angles ({closest_azimuth}, {closest_elevation})"
374 ))
375 })?;
376 let right_hrir = self.database.right_responses.get(&key).ok_or_else(|| {
377 crate::Error::LegacyHrtf(format!(
378 "No HRTF found for angles ({closest_azimuth}, {closest_elevation})"
379 ))
380 })?;
381
382 Ok((left_hrir.clone(), right_hrir.clone()))
383 }
384
385 fn get_bilinear_hrtf(
387 &self,
388 coords: &SphericalCoordinates,
389 ) -> crate::Result<(Array1<f32>, Array1<f32>)> {
390 let az_low = self.find_lower_angle(
392 coords.azimuth as i32,
393 &self.database.metadata.azimuth_angles,
394 );
395 let az_high = self.find_higher_angle(
396 coords.azimuth as i32,
397 &self.database.metadata.azimuth_angles,
398 );
399 let el_low = self.find_lower_angle(
400 coords.elevation as i32,
401 &self.database.metadata.elevation_angles,
402 );
403 let el_high = self.find_higher_angle(
404 coords.elevation as i32,
405 &self.database.metadata.elevation_angles,
406 );
407
408 let hrtf_00 = self.get_hrtf_at_angles(az_low, el_low)?;
410 let hrtf_01 = self.get_hrtf_at_angles(az_low, el_high)?;
411 let hrtf_10 = self.get_hrtf_at_angles(az_high, el_low)?;
412 let hrtf_11 = self.get_hrtf_at_angles(az_high, el_high)?;
413
414 let az_weight = if az_high != az_low {
416 (coords.azimuth - az_low as f32) / (az_high - az_low) as f32
417 } else {
418 0.0
419 };
420 let el_weight = if el_high != el_low {
421 (coords.elevation - el_low as f32) / (el_high - el_low) as f32
422 } else {
423 0.0
424 };
425
426 let left_hrir = self.interpolate_hrtf(&[
428 (&hrtf_00.0, (1.0 - az_weight) * (1.0 - el_weight)),
429 (&hrtf_01.0, (1.0 - az_weight) * el_weight),
430 (&hrtf_10.0, az_weight * (1.0 - el_weight)),
431 (&hrtf_11.0, az_weight * el_weight),
432 ]);
433
434 let right_hrir = self.interpolate_hrtf(&[
435 (&hrtf_00.1, (1.0 - az_weight) * (1.0 - el_weight)),
436 (&hrtf_01.1, (1.0 - az_weight) * el_weight),
437 (&hrtf_10.1, az_weight * (1.0 - el_weight)),
438 (&hrtf_11.1, az_weight * el_weight),
439 ]);
440
441 Ok((left_hrir, right_hrir))
442 }
443
444 fn get_spherical_hrtf(
446 &self,
447 coords: &SphericalCoordinates,
448 ) -> crate::Result<(Array1<f32>, Array1<f32>)> {
449 let mut left_sum = Array1::zeros(self.config.hrir_length);
450 let mut right_sum = Array1::zeros(self.config.hrir_length);
451 let mut weight_sum = 0.0;
452
453 let mut nearest_points = Vec::new();
455
456 for (&(az, el), left_hrir) in &self.database.left_responses {
457 let Some(right_hrir) = self.database.right_responses.get(&(az, el)) else {
458 continue;
459 };
460
461 let angular_distance = self.calculate_angular_distance(
463 coords.azimuth,
464 coords.elevation,
465 az as f32,
466 el as f32,
467 );
468
469 nearest_points.push((angular_distance, left_hrir, right_hrir));
470 }
471
472 nearest_points.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
474 nearest_points.truncate(4);
475
476 for (distance, left_hrir, right_hrir) in nearest_points {
478 let weight = if distance < 0.01 {
479 1.0 } else {
481 1.0 / (distance.to_radians().sin() + 0.001)
483 };
484
485 weight_sum += weight;
486
487 for i in 0..left_sum.len().min(left_hrir.len()) {
488 left_sum[i] += left_hrir[i] * weight;
489 right_sum[i] += right_hrir[i] * weight;
490 }
491 }
492
493 if weight_sum > 0.0 {
495 left_sum /= weight_sum;
496 right_sum /= weight_sum;
497 }
498
499 Ok((left_sum, right_sum))
500 }
501
502 fn get_weighted_hrtf(
504 &self,
505 coords: &SphericalCoordinates,
506 ) -> crate::Result<(Array1<f32>, Array1<f32>)> {
507 let mut left_sum = Array1::zeros(self.config.hrir_length);
508 let mut right_sum = Array1::zeros(self.config.hrir_length);
509 let mut weight_sum = 0.0;
510
511 for (&(az, el), left_hrir) in &self.database.left_responses {
513 let Some(right_hrir) = self.database.right_responses.get(&(az, el)) else {
514 continue;
515 };
516
517 let angular_distance = self.calculate_angular_distance(
519 coords.azimuth,
520 coords.elevation,
521 az as f32,
522 el as f32,
523 );
524
525 if angular_distance < 30.0 {
526 let weight = 1.0 / (1.0 + angular_distance);
528 weight_sum += weight;
529
530 for i in 0..left_sum.len() {
531 left_sum[i] += left_hrir[i] * weight;
532 right_sum[i] += right_hrir[i] * weight;
533 }
534 }
535 }
536
537 if weight_sum > 0.0 {
538 left_sum /= weight_sum;
539 right_sum /= weight_sum;
540 }
541
542 Ok((left_sum, right_sum))
543 }
544
545 fn convolve_hrtf(
547 &self,
548 input: &Array1<f32>,
549 left_hrir: &Array1<f32>,
550 right_hrir: &Array1<f32>,
551 left_output: &mut Array1<f32>,
552 right_output: &mut Array1<f32>,
553 ) -> crate::Result<()> {
554 let input_len = input.len();
555 let hrir_len = left_hrir.len();
556
557 if input_len < 64 || hrir_len < 64 {
559 return self.convolve_time_domain(
560 input,
561 left_hrir,
562 right_hrir,
563 left_output,
564 right_output,
565 );
566 }
567
568 self.convolve_frequency_domain(input, left_hrir, right_hrir, left_output, right_output)
570 }
571
572 fn convolve_time_domain(
574 &self,
575 input: &Array1<f32>,
576 left_hrir: &Array1<f32>,
577 right_hrir: &Array1<f32>,
578 left_output: &mut Array1<f32>,
579 right_output: &mut Array1<f32>,
580 ) -> crate::Result<()> {
581 left_output.fill(0.0);
582 right_output.fill(0.0);
583
584 for (i, &sample) in input.iter().enumerate() {
585 for (j, &hrir_sample) in left_hrir.iter().enumerate() {
586 if i + j < left_output.len() {
587 left_output[i + j] += sample * hrir_sample;
588 }
589 }
590 for (j, &hrir_sample) in right_hrir.iter().enumerate() {
591 if i + j < right_output.len() {
592 right_output[i + j] += sample * hrir_sample;
593 }
594 }
595 }
596
597 Ok(())
598 }
599
600 fn convolve_frequency_domain(
602 &self,
603 input: &Array1<f32>,
604 left_hrir: &Array1<f32>,
605 right_hrir: &Array1<f32>,
606 left_output: &mut Array1<f32>,
607 right_output: &mut Array1<f32>,
608 ) -> crate::Result<()> {
609 let input_len = input.len();
610 let hrir_len = left_hrir.len();
611 let conv_len = input_len + hrir_len - 1;
612
613 let fft_len = conv_len.next_power_of_two();
615
616 let input_complex: Vec<scirs2_core::Complex<f64>> = input
618 .iter()
619 .map(|&x| scirs2_core::Complex::new(x as f64, 0.0))
620 .chain(std::iter::repeat(scirs2_core::Complex::new(0.0, 0.0)))
621 .take(fft_len)
622 .collect();
623
624 let left_hrir_complex: Vec<scirs2_core::Complex<f64>> = left_hrir
625 .iter()
626 .map(|&x| scirs2_core::Complex::new(x as f64, 0.0))
627 .chain(std::iter::repeat(scirs2_core::Complex::new(0.0, 0.0)))
628 .take(fft_len)
629 .collect();
630
631 let right_hrir_complex: Vec<scirs2_core::Complex<f64>> = right_hrir
632 .iter()
633 .map(|&x| scirs2_core::Complex::new(x as f64, 0.0))
634 .chain(std::iter::repeat(scirs2_core::Complex::new(0.0, 0.0)))
635 .take(fft_len)
636 .collect();
637
638 let input_spectrum = scirs2_fft::fft(&input_complex, None)
640 .map_err(|e| crate::Error::LegacyProcessing(format!("FFT error: {e}")))?;
641
642 let left_hrir_spectrum = scirs2_fft::fft(&left_hrir_complex, None)
643 .map_err(|e| crate::Error::LegacyProcessing(format!("FFT error: {e}")))?;
644
645 let right_hrir_spectrum = scirs2_fft::fft(&right_hrir_complex, None)
646 .map_err(|e| crate::Error::LegacyProcessing(format!("FFT error: {e}")))?;
647
648 let left_result_spectrum: Vec<scirs2_core::Complex<f64>> = input_spectrum
650 .iter()
651 .zip(left_hrir_spectrum.iter())
652 .map(|(a, b)| a * b)
653 .collect();
654
655 let right_result_spectrum: Vec<scirs2_core::Complex<f64>> = input_spectrum
656 .iter()
657 .zip(right_hrir_spectrum.iter())
658 .map(|(a, b)| a * b)
659 .collect();
660
661 let left_result_time = scirs2_fft::ifft(&left_result_spectrum, None)
663 .map_err(|e| crate::Error::LegacyProcessing(format!("IFFT error: {e}")))?;
664
665 let right_result_time = scirs2_fft::ifft(&right_result_spectrum, None)
666 .map_err(|e| crate::Error::LegacyProcessing(format!("IFFT error: {e}")))?;
667
668 let output_len = left_output.len().min(conv_len);
670
671 for i in 0..output_len {
672 left_output[i] = left_result_time[i].re as f32;
673 right_output[i] = right_result_time[i].re as f32;
674 }
675
676 Ok(())
677 }
678
679 fn cartesian_to_spherical(&self, position: &Position3D) -> SphericalCoordinates {
681 let distance =
682 (position.x * position.x + position.y * position.y + position.z * position.z).sqrt();
683
684 let azimuth = if distance > 0.0 {
685 position.z.atan2(position.x).to_degrees()
686 } else {
687 0.0
688 };
689
690 let elevation = if distance > 0.0 {
691 (position.y / distance).asin().to_degrees()
692 } else {
693 0.0
694 };
695
696 SphericalCoordinates {
697 azimuth,
698 elevation,
699 distance,
700 }
701 }
702
703 fn find_closest_angle(&self, target: i32, available: &[i32]) -> i32 {
705 available
706 .iter()
707 .min_by_key(|&&angle| (angle - target).abs())
708 .copied()
709 .unwrap_or(0)
710 }
711
712 fn find_lower_angle(&self, target: i32, available: &[i32]) -> i32 {
714 available
715 .iter()
716 .filter(|&&angle| angle <= target)
717 .max()
718 .copied()
719 .unwrap_or(*available.first().unwrap_or(&0))
720 }
721
722 fn find_higher_angle(&self, target: i32, available: &[i32]) -> i32 {
724 available
725 .iter()
726 .filter(|&&angle| angle >= target)
727 .min()
728 .copied()
729 .unwrap_or(*available.last().unwrap_or(&0))
730 }
731
732 fn get_hrtf_at_angles(
734 &self,
735 azimuth: i32,
736 elevation: i32,
737 ) -> crate::Result<(Array1<f32>, Array1<f32>)> {
738 let key = (azimuth, elevation);
739 let left = self.database.left_responses.get(&key).ok_or_else(|| {
740 crate::Error::LegacyHrtf(format!("No HRTF for angles ({azimuth}, {elevation})"))
741 })?;
742 let right = self.database.right_responses.get(&key).ok_or_else(|| {
743 crate::Error::LegacyHrtf(format!("No HRTF for angles ({azimuth}, {elevation})"))
744 })?;
745 Ok((left.clone(), right.clone()))
746 }
747
748 fn interpolate_hrtf(&self, weighted_hrirs: &[(&Array1<f32>, f32)]) -> Array1<f32> {
750 let mut result = Array1::zeros(self.config.hrir_length);
751
752 for (hrir, weight) in weighted_hrirs {
753 for i in 0..result.len().min(hrir.len()) {
754 result[i] += hrir[i] * weight;
755 }
756 }
757
758 result
759 }
760
761 fn calculate_angular_distance(&self, az1: f32, el1: f32, az2: f32, el2: f32) -> f32 {
763 let az1_rad = az1.to_radians();
764 let el1_rad = el1.to_radians();
765 let az2_rad = az2.to_radians();
766 let el2_rad = el2.to_radians();
767
768 let cos_distance = el1_rad.sin() * el2_rad.sin()
770 + el1_rad.cos() * el2_rad.cos() * (az1_rad - az2_rad).cos();
771
772 cos_distance.clamp(-1.0, 1.0).acos().to_degrees()
773 }
774
775 fn apply_distance_modeling(
777 &self,
778 left_hrir: &mut Array1<f32>,
779 right_hrir: &mut Array1<f32>,
780 coords: &SphericalCoordinates,
781 ) -> crate::Result<()> {
782 let distance = coords.distance.max(0.01); let attenuation = 1.0 / distance;
786
787 let near_field_factor = if distance < self.config.near_field_distance {
789 let proximity_boost = (self.config.near_field_distance / distance).powf(0.3);
791 proximity_boost.min(3.0) } else {
793 1.0
794 };
795
796 let far_field_factor = if distance > self.config.far_field_distance {
798 0.8 + 0.2 * (self.config.far_field_distance / distance)
800 } else {
801 1.0
802 };
803
804 let total_gain = attenuation * near_field_factor * far_field_factor;
806
807 left_hrir.mapv_inplace(|x| x * total_gain);
809 right_hrir.mapv_inplace(|x| x * total_gain);
810
811 if distance < self.config.near_field_distance {
813 self.apply_proximity_delay(left_hrir, right_hrir, coords)?;
814 }
815
816 Ok(())
817 }
818
819 fn apply_air_absorption(
821 &self,
822 left_hrir: &mut Array1<f32>,
823 right_hrir: &mut Array1<f32>,
824 distance: f32,
825 ) -> crate::Result<()> {
826 let temp_celsius = self.config.temperature;
830 let relative_humidity = self.config.humidity;
831
832 let temp_kelvin = temp_celsius + 273.15;
834 let temp_ratio = temp_kelvin / 293.15; let h_rel = relative_humidity * (101.325 * temp_ratio.powf(-5.0241));
838
839 let distance_factor = (-distance / 100.0).exp(); let temp_factor = temp_ratio.powf(-0.1);
843 let humidity_factor = 1.0 - relative_humidity * 0.1;
844
845 let absorption_factor = distance_factor * temp_factor * humidity_factor;
846
847 for i in 0..left_hrir.len() {
850 let freq_weight = if i as f32 / left_hrir.len() as f32 > 0.5 {
851 absorption_factor.powf(1.0 + i as f32 / left_hrir.len() as f32)
853 } else {
854 absorption_factor
855 };
856
857 left_hrir[i] *= freq_weight;
858 right_hrir[i] *= freq_weight;
859 }
860
861 Ok(())
862 }
863
864 fn apply_proximity_delay(
866 &self,
867 left_hrir: &mut Array1<f32>,
868 right_hrir: &mut Array1<f32>,
869 coords: &SphericalCoordinates,
870 ) -> crate::Result<()> {
871 let distance = coords.distance;
872 let azimuth_rad = coords.azimuth.to_radians();
873
874 let head_radius =
876 self.config.head_circumference.unwrap_or(57.0) / (2.0 * std::f32::consts::PI);
877 let sound_speed = 343.0; let itd_samples = if distance < head_radius * 2.0 {
881 let enhanced_itd =
883 (head_radius * azimuth_rad.sin() * (1.0 + azimuth_rad.cos())) / sound_speed;
884 (enhanced_itd * self.config.sample_rate as f32) as usize
885 } else {
886 0
887 };
888
889 if itd_samples > 0 && azimuth_rad.abs() > 0.1 {
891 let delay_samples = itd_samples.min(left_hrir.len() / 4);
892
893 if azimuth_rad > 0.0 {
894 self.apply_delay(left_hrir, delay_samples);
896 } else {
897 self.apply_delay(right_hrir, delay_samples);
899 }
900 }
901
902 Ok(())
903 }
904
905 fn apply_delay(&self, hrir: &mut Array1<f32>, delay_samples: usize) {
907 if delay_samples == 0 || delay_samples >= hrir.len() {
908 return;
909 }
910
911 let original = hrir.clone();
913 hrir.fill(0.0);
914
915 for i in 0..(hrir.len() - delay_samples) {
916 hrir[i + delay_samples] = original[i];
917 }
918 }
919}
920
921impl HrtfDatabase {
922 pub async fn load_from_file(path: &std::path::Path) -> crate::Result<Self> {
924 let extension = path
925 .extension()
926 .and_then(|ext| ext.to_str())
927 .map(|ext| ext.to_lowercase());
928
929 match extension.as_deref() {
930 Some("sofa") => Self::load_sofa_file(path).await,
931 Some("json") => Self::load_json_file(path).await,
932 Some("bin") | Some("hrtf") => Self::load_binary_file(path).await,
933 _ => {
934 tracing::warn!("Unknown HRTF file format, using default database");
935 Self::load_default().await
936 }
937 }
938 }
939
940 async fn load_sofa_file(path: &std::path::Path) -> crate::Result<Self> {
942 tracing::info!("Loading SOFA HRTF file: {:?}", path);
943
944 let content = tokio::fs::read_to_string(path)
947 .await
948 .map_err(|e| crate::Error::hrtf(&format!("Failed to read SOFA file: {e}")))?;
949
950 let mut metadata = HrtfMetadata {
951 name: "SOFA HRTF Database".to_string(),
952 sample_rate: 44100,
953 hrir_length: 512,
954 azimuth_angles: Vec::new(),
955 elevation_angles: Vec::new(),
956 distances: Some(vec![1.0]),
957 subject_info: Some(SubjectInfo {
958 head_circumference: 56.0,
959 head_width: 15.0,
960 head_height: 20.0,
961 ear_height: 10.0,
962 shoulder_width: 40.0,
963 }),
964 };
965
966 let mut left_responses = HashMap::new();
967 let mut right_responses = HashMap::new();
968 let mut current_section = "";
969 let mut current_measurement: Option<(i32, i32)> = None;
970 let mut left_hrir_data = Vec::new();
971 let mut right_hrir_data = Vec::new();
972
973 for line in content.lines() {
974 let line = line.trim();
975 if line.is_empty() || line.starts_with('#') {
976 continue;
977 }
978
979 if line.starts_with('[') && line.ends_with(']') {
981 current_section = &line[1..line.len() - 1];
982 continue;
983 }
984
985 match current_section {
986 "GLOBAL" => {
987 if let Some((key, value)) = line.split_once('=') {
988 let key = key.trim();
989 let value = value.trim();
990 match key {
991 "Data.SamplingRate" => {
992 if let Ok(rate) = value.parse::<u32>() {
993 metadata.sample_rate = rate;
994 }
995 }
996 "Data.IRLength" => {
997 if let Ok(length) = value.parse::<usize>() {
998 metadata.hrir_length = length;
999 }
1000 }
1001 "GLOBAL:DatabaseName" => {
1002 metadata.name = value.to_string();
1003 }
1004 "GLOBAL:ListenerShortName" => {
1005 metadata.subject_info = Some(SubjectInfo {
1006 head_circumference: 56.0,
1007 head_width: 15.0,
1008 head_height: 20.0,
1009 ear_height: 10.0,
1010 shoulder_width: 40.0,
1011 });
1012 }
1013 _ => {}
1014 }
1015 }
1016 }
1017 "POSITION" if line.starts_with("SourcePosition") => {
1018 if let Some((_, coords)) = line.split_once('=') {
1020 let parts: Vec<&str> = coords.split(',').collect();
1021 if parts.len() >= 3 {
1022 if let (Ok(azimuth), Ok(elevation), Ok(distance)) = (
1023 parts[0].trim().parse::<f32>(),
1024 parts[1].trim().parse::<f32>(),
1025 parts[2].trim().parse::<f32>(),
1026 ) {
1027 current_measurement = Some((azimuth as i32, elevation as i32));
1028 metadata.distances = Some(vec![distance]);
1029 }
1030 }
1031 }
1032 }
1033 "DATA_IR" => {
1034 if let Some((azimuth, elevation)) = current_measurement {
1035 if let Some(data_str) = line.strip_prefix("L:") {
1036 left_hrir_data = data_str
1038 .split_whitespace()
1039 .filter_map(|s| s.parse::<f32>().ok())
1040 .collect();
1041 } else if let Some(data_str) = line.strip_prefix("R:") {
1042 right_hrir_data = data_str
1044 .split_whitespace()
1045 .filter_map(|s| s.parse::<f32>().ok())
1046 .collect();
1047 }
1048
1049 if !left_hrir_data.is_empty() && !right_hrir_data.is_empty() {
1051 left_responses
1052 .insert((azimuth, elevation), Array1::from(left_hrir_data.clone()));
1053 right_responses.insert(
1054 (azimuth, elevation),
1055 Array1::from(right_hrir_data.clone()),
1056 );
1057
1058 if !metadata.azimuth_angles.contains(&azimuth) {
1059 metadata.azimuth_angles.push(azimuth);
1060 }
1061 if !metadata.elevation_angles.contains(&elevation) {
1062 metadata.elevation_angles.push(elevation);
1063 }
1064
1065 left_hrir_data.clear();
1066 right_hrir_data.clear();
1067 current_measurement = None;
1068 }
1069 }
1070 }
1071 _ => {}
1072 }
1073 }
1074
1075 metadata.azimuth_angles.sort();
1076 metadata.elevation_angles.sort();
1077
1078 if left_responses.is_empty() || right_responses.is_empty() {
1080 tracing::warn!("No valid HRTF measurements found in SOFA file, using enhanced default");
1081 return Self::load_enhanced_default().await;
1082 }
1083
1084 tracing::info!(
1085 "Successfully loaded {} HRTF measurements from SOFA file",
1086 left_responses.len()
1087 );
1088
1089 Ok(Self {
1090 metadata,
1091 left_responses,
1092 right_responses,
1093 frequency_responses: None,
1094 distance_responses: None,
1095 })
1096 }
1097
1098 async fn load_json_file(path: &std::path::Path) -> crate::Result<Self> {
1100 tracing::info!("Loading JSON HRTF file: {:?}", path);
1101
1102 let content = tokio::fs::read_to_string(path)
1103 .await
1104 .map_err(|e| crate::Error::hrtf(&format!("Failed to read JSON file: {e}")))?;
1105
1106 let json_data: serde_json::Value = serde_json::from_str(&content)
1107 .map_err(|e| crate::Error::hrtf(&format!("Failed to parse JSON: {e}")))?;
1108
1109 let metadata = if let Some(meta) = json_data.get("metadata") {
1111 HrtfMetadata {
1112 name: meta
1113 .get("name")
1114 .and_then(|v| v.as_str())
1115 .unwrap_or("JSON HRTF Database")
1116 .to_string(),
1117 sample_rate: meta
1118 .get("sample_rate")
1119 .and_then(|v| v.as_u64())
1120 .unwrap_or(44100) as u32,
1121 hrir_length: meta
1122 .get("hrir_length")
1123 .and_then(|v| v.as_u64())
1124 .unwrap_or(512) as usize,
1125 azimuth_angles: meta
1126 .get("azimuth_angles")
1127 .and_then(|v| v.as_array())
1128 .map(|arr| {
1129 arr.iter()
1130 .filter_map(|x| x.as_i64().map(|i| i as i32))
1131 .collect()
1132 })
1133 .unwrap_or_else(|| (-180..=180).step_by(15).collect()),
1134 elevation_angles: meta
1135 .get("elevation_angles")
1136 .and_then(|v| v.as_array())
1137 .map(|arr| {
1138 arr.iter()
1139 .filter_map(|x| x.as_i64().map(|i| i as i32))
1140 .collect()
1141 })
1142 .unwrap_or_else(|| (-40..=90).step_by(10).collect()),
1143 distances: meta
1144 .get("distance")
1145 .and_then(|v| v.as_f64())
1146 .map(|d| vec![d as f32]),
1147 subject_info: meta.get("subject_id").and_then(|v| v.as_str()).map(|_| {
1148 SubjectInfo {
1149 head_circumference: 56.0,
1150 head_width: 15.0,
1151 head_height: 20.0,
1152 ear_height: 10.0,
1153 shoulder_width: 40.0,
1154 }
1155 }),
1156 }
1157 } else {
1158 return Self::load_enhanced_default().await;
1159 };
1160
1161 let mut left_responses = HashMap::new();
1163 let mut right_responses = HashMap::new();
1164
1165 if let Some(measurements) = json_data.get("measurements").and_then(|v| v.as_array()) {
1166 for measurement in measurements {
1167 let azimuth = measurement
1168 .get("azimuth")
1169 .and_then(|v| v.as_i64())
1170 .unwrap_or(0) as i32;
1171 let elevation = measurement
1172 .get("elevation")
1173 .and_then(|v| v.as_i64())
1174 .unwrap_or(0) as i32;
1175
1176 if let Some(left_hrir) = measurement.get("left_hrir").and_then(|v| v.as_array()) {
1178 let left_data: Vec<f32> = left_hrir
1179 .iter()
1180 .filter_map(|x| x.as_f64().map(|f| f as f32))
1181 .collect();
1182 if !left_data.is_empty() {
1183 left_responses.insert((azimuth, elevation), Array1::from(left_data));
1184 }
1185 }
1186
1187 if let Some(right_hrir) = measurement.get("right_hrir").and_then(|v| v.as_array()) {
1189 let right_data: Vec<f32> = right_hrir
1190 .iter()
1191 .filter_map(|x| x.as_f64().map(|f| f as f32))
1192 .collect();
1193 if !right_data.is_empty() {
1194 right_responses.insert((azimuth, elevation), Array1::from(right_data));
1195 }
1196 }
1197 }
1198 }
1199
1200 if left_responses.is_empty() || right_responses.is_empty() {
1202 tracing::warn!("No valid HRTF measurements found in JSON file, using enhanced default");
1203 return Self::load_enhanced_default().await;
1204 }
1205
1206 tracing::info!(
1207 "Successfully loaded {} HRTF measurements from JSON",
1208 left_responses.len()
1209 );
1210
1211 Ok(Self {
1212 metadata,
1213 left_responses,
1214 right_responses,
1215 frequency_responses: None,
1216 distance_responses: None,
1217 })
1218 }
1219
1220 async fn load_binary_file(path: &std::path::Path) -> crate::Result<Self> {
1222 tracing::info!("Loading binary HRTF file: {:?}", path);
1223
1224 let data = tokio::fs::read(path)
1225 .await
1226 .map_err(|e| crate::Error::hrtf(&format!("Failed to read binary file: {e}")))?;
1227
1228 if data.len() < 32 {
1229 return Err(crate::Error::hrtf(
1230 "Binary file too small to contain valid HRTF data",
1231 ));
1232 }
1233
1234 let mut cursor = 0;
1235
1236 let magic = &data[cursor..cursor + 4];
1239 if magic != b"HRTF" {
1240 return Err(crate::Error::hrtf("Invalid binary HRTF file format"));
1241 }
1242 cursor += 4;
1243
1244 let version = u32::from_le_bytes([
1246 data[cursor],
1247 data[cursor + 1],
1248 data[cursor + 2],
1249 data[cursor + 3],
1250 ]);
1251 cursor += 4;
1252 if version != 1 {
1253 return Err(crate::Error::hrtf(&format!(
1254 "Unsupported HRTF binary version: {version}"
1255 )));
1256 }
1257
1258 let sample_rate = u32::from_le_bytes([
1260 data[cursor],
1261 data[cursor + 1],
1262 data[cursor + 2],
1263 data[cursor + 3],
1264 ]);
1265 cursor += 4;
1266
1267 let hrir_length = u32::from_le_bytes([
1269 data[cursor],
1270 data[cursor + 1],
1271 data[cursor + 2],
1272 data[cursor + 3],
1273 ]) as usize;
1274 cursor += 4;
1275
1276 let measurement_count = u32::from_le_bytes([
1278 data[cursor],
1279 data[cursor + 1],
1280 data[cursor + 2],
1281 data[cursor + 3],
1282 ]) as usize;
1283 cursor += 4;
1284
1285 let distance = f32::from_le_bytes([
1287 data[cursor],
1288 data[cursor + 1],
1289 data[cursor + 2],
1290 data[cursor + 3],
1291 ]);
1292 cursor += 4;
1293
1294 let subject_id_len = u32::from_le_bytes([
1296 data[cursor],
1297 data[cursor + 1],
1298 data[cursor + 2],
1299 data[cursor + 3],
1300 ]) as usize;
1301 cursor += 4;
1302
1303 if cursor + subject_id_len > data.len() {
1304 return Err(crate::Error::hrtf(
1305 "Invalid subject ID length in binary file",
1306 ));
1307 }
1308
1309 let subject_id =
1310 String::from_utf8_lossy(&data[cursor..cursor + subject_id_len]).to_string();
1311 cursor += subject_id_len;
1312
1313 let mut left_responses = HashMap::new();
1315 let mut right_responses = HashMap::new();
1316 let mut azimuth_angles = Vec::new();
1317 let mut elevation_angles = Vec::new();
1318
1319 for _ in 0..measurement_count {
1320 if cursor + 8 + (hrir_length * 8) > data.len() {
1321 return Err(crate::Error::hrtf("Insufficient data for measurement"));
1322 }
1323
1324 let azimuth = i32::from_le_bytes([
1326 data[cursor],
1327 data[cursor + 1],
1328 data[cursor + 2],
1329 data[cursor + 3],
1330 ]);
1331 cursor += 4;
1332
1333 let elevation = i32::from_le_bytes([
1335 data[cursor],
1336 data[cursor + 1],
1337 data[cursor + 2],
1338 data[cursor + 3],
1339 ]);
1340 cursor += 4;
1341
1342 let mut left_hrir = Vec::with_capacity(hrir_length);
1344 for _ in 0..hrir_length {
1345 let sample = f32::from_le_bytes([
1346 data[cursor],
1347 data[cursor + 1],
1348 data[cursor + 2],
1349 data[cursor + 3],
1350 ]);
1351 left_hrir.push(sample);
1352 cursor += 4;
1353 }
1354
1355 let mut right_hrir = Vec::with_capacity(hrir_length);
1357 for _ in 0..hrir_length {
1358 let sample = f32::from_le_bytes([
1359 data[cursor],
1360 data[cursor + 1],
1361 data[cursor + 2],
1362 data[cursor + 3],
1363 ]);
1364 right_hrir.push(sample);
1365 cursor += 4;
1366 }
1367
1368 left_responses.insert((azimuth, elevation), Array1::from(left_hrir));
1369 right_responses.insert((azimuth, elevation), Array1::from(right_hrir));
1370
1371 if !azimuth_angles.contains(&azimuth) {
1372 azimuth_angles.push(azimuth);
1373 }
1374 if !elevation_angles.contains(&elevation) {
1375 elevation_angles.push(elevation);
1376 }
1377 }
1378
1379 azimuth_angles.sort();
1380 elevation_angles.sort();
1381
1382 if left_responses.is_empty() || right_responses.is_empty() {
1383 return Err(crate::Error::hrtf(
1384 "No valid HRTF measurements found in binary file",
1385 ));
1386 }
1387
1388 let metadata = HrtfMetadata {
1389 name: format!("Binary HRTF Database ({subject_id})"),
1390 sample_rate,
1391 hrir_length,
1392 azimuth_angles,
1393 elevation_angles,
1394 distances: Some(vec![distance]),
1395 subject_info: Some(SubjectInfo {
1396 head_circumference: 56.0,
1397 head_width: 15.0,
1398 head_height: 20.0,
1399 ear_height: 10.0,
1400 shoulder_width: 40.0,
1401 }),
1402 };
1403
1404 tracing::info!(
1405 "Successfully loaded {} HRTF measurements from binary file",
1406 left_responses.len()
1407 );
1408
1409 Ok(Self {
1410 metadata,
1411 left_responses,
1412 right_responses,
1413 frequency_responses: None,
1414 distance_responses: None,
1415 })
1416 }
1417
1418 async fn load_enhanced_default() -> crate::Result<Self> {
1420 let metadata = HrtfMetadata {
1421 name: "Enhanced Default HRTF".to_string(),
1422 sample_rate: 44100,
1423 hrir_length: 512, azimuth_angles: (-180..=180).step_by(5).collect(), elevation_angles: (-90..=90).step_by(5).collect(), distances: Some(vec![0.2, 0.5, 1.0, 2.0, 5.0]), subject_info: Some(SubjectInfo {
1428 head_circumference: 57.0,
1429 head_width: 15.5,
1430 head_height: 24.0,
1431 ear_height: 12.0,
1432 shoulder_width: 45.0,
1433 }),
1434 };
1435
1436 let mut left_responses = HashMap::new();
1437 let mut right_responses = HashMap::new();
1438 let mut distance_responses = HashMap::new();
1439
1440 for &azimuth in &metadata.azimuth_angles {
1442 for &elevation in &metadata.elevation_angles {
1443 let distances = metadata
1444 .distances
1445 .as_ref()
1446 .expect("distances must be provided in enhanced HRTF metadata");
1447 for &distance in distances {
1448 let (left_hrir, right_hrir) = Self::generate_enhanced_hrtf(
1449 azimuth,
1450 elevation,
1451 distance,
1452 metadata.hrir_length,
1453 );
1454
1455 if (distance - 1.0).abs() < 0.1 {
1457 left_responses.insert((azimuth, elevation), left_hrir.clone());
1458 right_responses.insert((azimuth, elevation), right_hrir.clone());
1459 }
1460
1461 let distance_key = (distance * 100.0) as u32; distance_responses
1464 .insert((azimuth, elevation, distance_key), (left_hrir, right_hrir));
1465 }
1466 }
1467 }
1468
1469 Ok(Self {
1470 metadata,
1471 left_responses,
1472 right_responses,
1473 frequency_responses: None,
1474 distance_responses: Some(distance_responses),
1475 })
1476 }
1477
1478 pub async fn load_default() -> crate::Result<Self> {
1480 let metadata = HrtfMetadata {
1482 name: "Default HRTF".to_string(),
1483 sample_rate: 44100,
1484 hrir_length: 256,
1485 azimuth_angles: (-180..=180).step_by(15).collect(),
1486 elevation_angles: (-90..=90).step_by(15).collect(),
1487 distances: None,
1488 subject_info: None,
1489 };
1490
1491 let mut left_responses = HashMap::new();
1492 let mut right_responses = HashMap::new();
1493
1494 for &azimuth in &metadata.azimuth_angles {
1496 for &elevation in &metadata.elevation_angles {
1497 let (left_hrir, right_hrir) =
1498 Self::generate_simple_hrtf(azimuth, elevation, metadata.hrir_length);
1499 left_responses.insert((azimuth, elevation), left_hrir);
1500 right_responses.insert((azimuth, elevation), right_hrir);
1501 }
1502 }
1503
1504 Ok(Self {
1505 metadata,
1506 left_responses,
1507 right_responses,
1508 frequency_responses: None,
1509 distance_responses: None,
1510 })
1511 }
1512
1513 fn generate_simple_hrtf(
1515 azimuth: i32,
1516 _elevation: i32,
1517 length: usize,
1518 ) -> (Array1<f32>, Array1<f32>) {
1519 let mut left_hrir = Array1::zeros(length);
1520 let mut right_hrir = Array1::zeros(length);
1521
1522 let _azimuth_rad = (azimuth as f32).to_radians();
1524
1525 let left_delay = if azimuth < 0 {
1527 0
1528 } else {
1529 (azimuth as f32 / 180.0 * 10.0) as usize
1530 };
1531 let left_gain = 1.0 - (azimuth as f32).abs() / 180.0 * 0.3;
1532
1533 let right_delay = if azimuth > 0 {
1535 0
1536 } else {
1537 ((-azimuth) as f32 / 180.0 * 10.0) as usize
1538 };
1539 let right_gain = 1.0 - (azimuth as f32).abs() / 180.0 * 0.3;
1540
1541 if left_delay < length {
1543 left_hrir[left_delay] = left_gain;
1544 }
1545 if right_delay < length {
1546 right_hrir[right_delay] = right_gain;
1547 }
1548
1549 (left_hrir, right_hrir)
1550 }
1551
1552 fn generate_enhanced_hrtf(
1554 azimuth: i32,
1555 elevation: i32,
1556 distance: f32,
1557 length: usize,
1558 ) -> (Array1<f32>, Array1<f32>) {
1559 let sample_rate = 44100.0; let mut left_hrir = Array1::zeros(length);
1561 let mut right_hrir = Array1::zeros(length);
1562
1563 let azimuth_rad = (azimuth as f32).to_radians();
1565 let elevation_rad = (elevation as f32).to_radians();
1566
1567 let head_radius = 0.09; let distance_attenuation = 1.0 / (distance + 0.01);
1572
1573 let itd = if azimuth_rad.abs() <= std::f32::consts::PI / 2.0 {
1575 (head_radius / 343.0) * (azimuth_rad + azimuth_rad.sin()) * sample_rate
1577 } else {
1578 (head_radius / 343.0) * (std::f32::consts::PI / 2.0 + azimuth_rad.sin()) * sample_rate
1580 };
1581
1582 let left_delay = if azimuth >= 0 {
1584 (itd / 2.0) as usize
1585 } else {
1586 0
1587 };
1588 let right_delay = if azimuth < 0 {
1589 (-itd / 2.0) as usize
1590 } else {
1591 0
1592 };
1593
1594 let frequency_factor = 1.0; let shadow_attenuation = if azimuth_rad.abs() > std::f32::consts::PI / 2.0 {
1597 0.3
1598 } else {
1599 0.0
1600 };
1601
1602 let left_gain = distance_attenuation
1603 * frequency_factor
1604 * (1.0 - if azimuth > 0 { shadow_attenuation } else { 0.0 });
1605 let right_gain = distance_attenuation
1606 * frequency_factor
1607 * (1.0 - if azimuth < 0 { shadow_attenuation } else { 0.0 });
1608
1609 let elevation_gain = (1.0 + 0.2 * elevation_rad.sin()).clamp(0.5, 1.5);
1611
1612 let near_field_boost = if distance < 0.5 {
1614 1.0 + (0.5 - distance) * 0.5
1615 } else {
1616 1.0
1617 };
1618
1619 let primary_delay = (distance / 343.0 * sample_rate) as usize;
1621
1622 if primary_delay + left_delay < length {
1624 left_hrir[primary_delay + left_delay] = left_gain * elevation_gain * near_field_boost;
1625 }
1626 if primary_delay + right_delay < length {
1627 right_hrir[primary_delay + right_delay] =
1628 right_gain * elevation_gain * near_field_boost;
1629 }
1630
1631 let reflection_delay = primary_delay + (0.002 * sample_rate) as usize;
1633 if reflection_delay < length {
1634 let reflection_gain = 0.1 * distance_attenuation;
1635 if reflection_delay + left_delay < length {
1636 left_hrir[reflection_delay + left_delay] += reflection_gain;
1637 }
1638 if reflection_delay + right_delay < length {
1639 right_hrir[reflection_delay + right_delay] += reflection_gain;
1640 }
1641 }
1642
1643 let window_size = (length / 8).min(32);
1645 for i in 0..window_size {
1646 let window_val =
1647 0.5 * (1.0 - ((i as f32) / (window_size as f32) * std::f32::consts::PI).cos());
1648 if i < left_hrir.len() {
1649 left_hrir[i] *= window_val;
1650 }
1651 if i < right_hrir.len() {
1652 right_hrir[i] *= window_val;
1653 }
1654 }
1655
1656 (left_hrir, right_hrir)
1657 }
1658
1659 pub fn metadata(&self) -> &HrtfMetadata {
1661 &self.metadata
1662 }
1663
1664 pub fn available_positions(&self) -> Vec<SphericalCoordinates> {
1666 let mut positions = Vec::new();
1667 for &azimuth in &self.metadata.azimuth_angles {
1668 for &elevation in &self.metadata.elevation_angles {
1669 positions.push(SphericalCoordinates {
1670 azimuth: azimuth as f32,
1671 elevation: elevation as f32,
1672 distance: 1.0, });
1674 }
1675 }
1676 positions
1677 }
1678}
1679
1680impl Default for HrtfConfig {
1681 fn default() -> Self {
1682 Self {
1683 sample_rate: 44100,
1684 hrir_length: 256,
1685 crossfade_time: 0.01, enable_distance_modeling: true,
1687 interpolation_method: InterpolationMethod::Bilinear,
1688 head_circumference: None,
1689 near_field_distance: 0.2, far_field_distance: 10.0, enable_air_absorption: true,
1692 temperature: 20.0, humidity: 0.5, enable_simd: true,
1695 }
1696 }
1697}
1698
#[cfg(test)]
mod tests {
    use super::*;

    // Fallible calls are unwrapped with `expect` / `panic!` rather than checked
    // with `assert!(result.is_ok())`, so a failing test reports the underlying
    // error instead of just `false`.

    #[tokio::test]
    async fn test_hrtf_processor_creation() {
        let _processor = HrtfProcessor::new(None)
            .await
            .expect("HrtfProcessor::new(None) failed");
    }

    #[tokio::test]
    async fn test_hrtf_database_loading() {
        let db = HrtfDatabase::load_default()
            .await
            .expect("HrtfDatabase::load_default failed");

        assert!(!db.left_responses.is_empty());
        assert!(!db.right_responses.is_empty());
        assert_eq!(db.left_responses.len(), db.right_responses.len());
    }

    #[tokio::test]
    async fn test_cartesian_to_spherical() {
        let processor = HrtfProcessor::new(None).await.unwrap();

        // +x maps to azimuth 0°, elevation 0°.
        let spherical = processor.cartesian_to_spherical(&Position3D::new(1.0, 0.0, 0.0));
        assert!(spherical.azimuth.abs() < 0.1, "azimuth = {}", spherical.azimuth);
        assert!(spherical.elevation.abs() < 0.1, "elevation = {}", spherical.elevation);

        // +z maps to azimuth 90°.
        let spherical = processor.cartesian_to_spherical(&Position3D::new(0.0, 0.0, 1.0));
        assert!(
            (spherical.azimuth - 90.0).abs() < 0.1,
            "azimuth = {}",
            spherical.azimuth
        );
    }

    #[tokio::test]
    async fn test_hrtf_processing() {
        let processor = HrtfProcessor::new(None).await.unwrap();
        let input = Array1::from_vec(vec![1.0, 0.5, -0.5, -1.0]);
        let mut left_output = Array1::zeros(input.len());
        let mut right_output = Array1::zeros(input.len());

        let position = Position3D::new(1.0, 0.0, 0.0);
        processor
            .process_position(&input, &mut left_output, &mut right_output, &position)
            .await
            .expect("process_position failed");
    }

    #[tokio::test]
    async fn test_realtime_processing() {
        let mut processor = HrtfProcessor::new(None).await.unwrap();
        let chunk_size = 64;
        let input = Array1::from_vec(vec![0.1; chunk_size]);
        let mut left_output = Array1::zeros(chunk_size);
        let mut right_output = Array1::zeros(chunk_size);

        let position = Position3D::new(1.0, 0.0, 0.0);

        // Two consecutive chunks exercise the state carried between calls.
        processor
            .process_realtime_chunk(&input, &mut left_output, &mut right_output, &position)
            .await
            .expect("first realtime chunk failed");
        processor
            .process_realtime_chunk(&input, &mut left_output, &mut right_output, &position)
            .await
            .expect("second realtime chunk failed");
    }

    #[tokio::test]
    async fn test_interpolation_methods() {
        let methods = [
            (InterpolationMethod::Nearest, "Nearest"),
            (InterpolationMethod::Bilinear, "Bilinear"),
            (InterpolationMethod::Spherical, "Spherical"),
            (InterpolationMethod::Weighted, "Weighted"),
        ];

        for (method, name) in methods {
            let config = HrtfConfig {
                interpolation_method: method,
                ..Default::default()
            };

            let processor = HrtfProcessor::with_config(None, config).await.unwrap();
            let coords = SphericalCoordinates {
                azimuth: 45.0,
                elevation: 15.0,
                distance: 1.0,
            };

            if let Err(e) = processor.get_hrtf(&coords) {
                panic!("Failed interpolation method: {} ({:?})", name, e);
            }
        }
    }

    #[tokio::test]
    async fn test_crossfade_processing() {
        let processor = HrtfProcessor::new(None).await.unwrap();
        let input = Array1::from_vec(vec![1.0, 0.5, -0.5, -1.0]);
        let mut left_output = Array1::zeros(input.len());
        let mut right_output = Array1::zeros(input.len());

        // (position, weight) pairs blended by the crossfade.
        let positions = vec![
            (Position3D::new(1.0, 0.0, 0.0), 0.7),
            (Position3D::new(-1.0, 0.0, 0.0), 0.3),
        ];

        processor
            .process_crossfade(&input, &mut left_output, &mut right_output, &positions)
            .await
            .expect("process_crossfade failed");

        let left_sum: f32 = left_output.iter().map(|x| x.abs()).sum();
        let right_sum: f32 = right_output.iter().map(|x| x.abs()).sum();
        assert!(left_sum > 0.0, "left output is silent");
        assert!(right_sum > 0.0, "right output is silent");
    }

    #[tokio::test]
    async fn test_frequency_domain_convolution() {
        let processor = HrtfProcessor::new(None).await.unwrap();

        let input = Array1::from_vec(vec![0.1; 128]);
        let hrir_len = processor.config.hrir_length;
        let left_hrir = Array1::from_vec(vec![1.0; hrir_len]);
        let right_hrir = Array1::from_vec(vec![0.8; hrir_len]);

        let mut left_output = Array1::zeros(input.len());
        let mut right_output = Array1::zeros(input.len());

        processor
            .convolve_hrtf(
                &input,
                &left_hrir,
                &right_hrir,
                &mut left_output,
                &mut right_output,
            )
            .expect("convolve_hrtf failed");

        let left_energy: f32 = left_output.iter().map(|x| x * x).sum();
        let right_energy: f32 = right_output.iter().map(|x| x * x).sum();
        assert!(left_energy > 0.0, "left output has no energy");
        assert!(right_energy > 0.0, "right output has no energy");
    }
}