pub mod ai_personalization;
pub mod database;
use crate::types::Position3D;
pub use database::{
DatabaseConfig, DatabaseStatistics, HrtfDatabaseManager, HrtfMeasurement as DbHrtfMeasurement,
HrtfPosition, InterpolationMethod as DbInterpolationMethod, PersonalizedHrtf, StorageFormat,
};
use scirs2_core::ndarray::Array1;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::PathBuf;
use std::sync::Arc;
/// Binaural renderer: spatialises a mono signal by convolving it with
/// left/right head-related impulse responses (HRIRs) looked up in an
/// [`HrtfDatabase`].
pub struct HrtfProcessor {
/// HRIR database, held behind an `Arc` so it can be shared without copying.
database: Arc<HrtfDatabase>,
/// Overlap buffer size chosen at construction (`2 * config.hrir_length`).
/// Not read anywhere in this section of the file, hence the `dead_code` allow.
#[allow(dead_code)]
buffer_size: usize,
/// Left-ear convolution tail carried between `process_realtime_chunk` calls.
// NOTE(review): this field is read by `process_realtime_chunk`, so the
// `dead_code` allow below looks stale — confirm and remove.
#[allow(dead_code)]
overlap_left: Array1<f32>,
/// Right-ear counterpart of `overlap_left`.
// NOTE(review): same stale-looking `dead_code` allow as `overlap_left`.
#[allow(dead_code)]
overlap_right: Array1<f32>,
/// Processing options (interpolation method, distance/air modelling, ...).
config: HrtfConfig,
}
/// In-memory HRIR set: one impulse response per ear for each measured
/// `(azimuth, elevation)` grid point, in whole degrees.
#[derive(Clone)]
pub struct HrtfDatabase {
/// Grid description and provenance (name, sample rate, angle lists, ...).
metadata: HrtfMetadata,
/// Left-ear HRIRs keyed by `(azimuth, elevation)` in whole degrees.
left_responses: HashMap<(i32, i32), Array1<f32>>,
/// Right-ear HRIRs keyed like `left_responses`; processor lookups need the
/// key to be present in both maps.
right_responses: HashMap<(i32, i32), Array1<f32>>,
/// Every loader in this section sets this to `None`; it is not read here.
#[allow(dead_code)]
#[allow(clippy::type_complexity)]
frequency_responses: Option<HashMap<(i32, i32), (Array1<f32>, Array1<f32>)>>,
/// Per-distance HRIR pairs keyed by `(azimuth, elevation, (distance * 100.0) as u32)`.
/// Only `load_enhanced_default` fills it; it is not read in this section.
#[allow(dead_code)]
#[allow(clippy::type_complexity)]
distance_responses: Option<HashMap<(i32, i32, u32), (Array1<f32>, Array1<f32>)>>,
}
/// Tunable options for [`HrtfProcessor`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HrtfConfig {
/// Sample rate in Hz; used to convert the near-field ITD from seconds to samples.
pub sample_rate: u32,
/// HRIR length in samples; sizes interpolated HRIRs and the overlap buffers.
pub hrir_length: usize,
/// Crossfade duration.
// NOTE(review): not read in this section; units presumably seconds — confirm.
pub crossfade_time: f32,
/// Enables inverse-distance gain, near/far-field shaping and proximity delay.
pub enable_distance_modeling: bool,
/// How an HRIR pair is derived for directions between grid points.
pub interpolation_method: InterpolationMethod,
/// Head circumference used for the near-field ITD; `None` falls back to 57.0.
// NOTE(review): values look like centimetres — confirm.
pub head_circumference: Option<f32>,
/// Below this distance a proximity gain boost and extra ear delay are
/// applied; air absorption is only applied beyond it.
pub near_field_distance: f32,
/// Beyond this distance an extra gain roll-off is applied.
pub far_field_distance: f32,
/// Enables the air-absorption stage for sources beyond `near_field_distance`.
pub enable_air_absorption: bool,
/// Air temperature in °C (the absorption model adds 273.15 to get kelvin).
pub temperature: f32,
/// Humidity fed to the absorption model.
// NOTE(review): scale (0..1 fraction vs percent) is not established here.
pub humidity: f32,
/// SIMD toggle.
// NOTE(review): not read in this section.
pub enable_simd: bool,
}
/// Description of an HRTF database's measurement grid and provenance.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HrtfMetadata {
/// Human-readable database name.
pub name: String,
/// Sample rate (Hz) declared for the stored HRIRs.
pub sample_rate: u32,
/// Declared HRIR length in samples.
// NOTE(review): the SOFA and JSON loaders take this from the file without
// checking it against the lengths of the stored HRIRs.
pub hrir_length: usize,
/// Azimuth grid in whole degrees; nearest/bilinear lookups snap to it.
pub azimuth_angles: Vec<i32>,
/// Elevation grid in whole degrees.
pub elevation_angles: Vec<i32>,
/// Measurement distance(s), if known.
pub distances: Option<Vec<f32>>,
/// Listener anthropometrics, if known.
pub subject_info: Option<SubjectInfo>,
}
/// Listener anthropometric measurements attached to a database.
// NOTE(review): units are not stated; the values used in this file
// (e.g. 56.0-57.0 for the circumference) look like centimetres — confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubjectInfo {
/// Head circumference.
pub head_circumference: f32,
/// Head width.
pub head_width: f32,
/// Head height.
pub head_height: f32,
/// Ear height.
pub ear_height: f32,
/// Shoulder width.
pub shoulder_width: f32,
}
/// Strategy [`HrtfProcessor`] uses to derive an HRIR pair for an arbitrary direction.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpolationMethod {
/// Use the measurement at the closest grid azimuth/elevation.
Nearest,
/// Blend the four grid corners that bracket the direction.
Bilinear,
/// Blend the four measurements nearest by great-circle angle.
Spherical,
/// Blend every measurement within 30°, weighted by inverse angular distance.
Weighted,
}
/// Source direction and range. When produced by `HrtfProcessor`'s Cartesian
/// conversion, azimuth is `atan2(z, x)` and elevation is `asin(y / distance)`,
/// both in degrees.
#[derive(Debug, Clone, Copy)]
pub struct SphericalCoordinates {
/// Azimuth in degrees.
pub azimuth: f32,
/// Elevation in degrees.
pub elevation: f32,
/// Euclidean distance from the origin, in the same units as `Position3D`.
pub distance: f32,
}
/// One measured HRIR pair together with the direction it was measured at.
// NOTE(review): not referenced anywhere in this section of the file.
#[derive(Debug, Clone)]
pub struct HrtfMeasurement {
/// Measurement direction and distance.
pub coordinates: SphericalCoordinates,
/// Left-ear impulse response.
pub left_hrir: Array1<f32>,
/// Right-ear impulse response.
pub right_hrir: Array1<f32>,
}
impl HrtfProcessor {
pub async fn new(database_path: Option<PathBuf>) -> crate::Result<Self> {
let database = if let Some(path) = database_path {
HrtfDatabase::load_from_file(&path).await?
} else {
HrtfDatabase::load_default().await?
};
let config = HrtfConfig::default();
let buffer_size = config.hrir_length * 2;
Ok(Self {
database: Arc::new(database),
buffer_size,
overlap_left: Array1::zeros(buffer_size),
overlap_right: Array1::zeros(buffer_size),
config,
})
}
pub async fn new_default() -> crate::Result<Self> {
Self::new(None).await
}
/// Creates a processor with an explicit configuration.
///
/// `database_path` selects an HRTF file; `None` uses the built-in synthetic
/// database. The overlap buffers used by streaming are sized to twice
/// `config.hrir_length`.
///
/// # Errors
/// Propagates any error raised while loading the database.
pub async fn with_config(
    database_path: Option<PathBuf>,
    config: HrtfConfig,
) -> crate::Result<Self> {
    let database = match database_path {
        Some(path) => HrtfDatabase::load_from_file(&path).await?,
        None => HrtfDatabase::load_default().await?,
    };
    let buffer_size = 2 * config.hrir_length;
    let overlap_left = Array1::zeros(buffer_size);
    let overlap_right = Array1::zeros(buffer_size);
    Ok(Self {
        database: Arc::new(database),
        buffer_size,
        overlap_left,
        overlap_right,
        config,
    })
}
/// Renders `input` for a static source at `position` into the two ear buffers.
///
/// The steps are:
/// 1. Convert `position` from Cartesian to spherical coordinates.
/// 2. Look up the HRIR pair with the configured interpolation method.
/// 3. Optionally apply distance modelling.
/// 4. Optionally apply air absorption, only beyond `near_field_distance`.
/// 5. Convolve `input` with the HRIR pair.
///
/// # Errors
/// Propagates HRIR lookup, modelling and convolution errors.
pub async fn process_position(
    &self,
    input: &Array1<f32>,
    left_output: &mut Array1<f32>,
    right_output: &mut Array1<f32>,
    position: &Position3D,
) -> crate::Result<()> {
    let coords = self.cartesian_to_spherical(position);
    let (mut hrir_l, mut hrir_r) = self.get_hrtf(&coords)?;

    if self.config.enable_distance_modeling {
        self.apply_distance_modeling(&mut hrir_l, &mut hrir_r, &coords)?;
    }

    let beyond_near_field = coords.distance > self.config.near_field_distance;
    if self.config.enable_air_absorption && beyond_near_field {
        self.apply_air_absorption(&mut hrir_l, &mut hrir_r, coords.distance)?;
    }

    self.convolve_hrtf(input, &hrir_l, &hrir_r, left_output, right_output)
}
/// Renders `input` for a source at a point between `start_position` and
/// `end_position`.
///
/// `progress` is the interpolation parameter: 0.0 is the start and 1.0 is
/// the end. Values outside that range are clamped, so the source never
/// overshoots either endpoint.
///
/// # Errors
/// Propagates errors from [`Self::process_position`].
pub async fn process_position_smooth(
    &self,
    input: &Array1<f32>,
    left_output: &mut Array1<f32>,
    right_output: &mut Array1<f32>,
    start_position: &Position3D,
    end_position: &Position3D,
    progress: f32,
) -> crate::Result<()> {
    let t = progress.clamp(0.0, 1.0);
    let lerp = |from: f32, to: f32| from * (1.0 - t) + to * t;
    let current_position = Position3D::new(
        lerp(start_position.x, end_position.x),
        lerp(start_position.y, end_position.y),
        lerp(start_position.z, end_position.z),
    );
    // The original passed a mis-encoded `¤t_position` (HTML-entity damage
    // of `&current_position`), which does not compile.
    self.process_position(input, left_output, right_output, &current_position)
        .await
}
/// Streams one block of audio through the HRTF using overlap-add.
///
/// Each call does the following:
/// - convolves `input` with the HRIR pair for `position`;
/// - adds the tail still pending from earlier calls;
/// - writes `input.len()` samples to each output;
/// - keeps the rest of the tail for the next call.
///
/// Call [`Self::reset_buffers`] when the stream restarts.
///
/// # Errors
/// - Returns a processing error when either output's length differs from
///   `input.len()`.
/// - Propagates HRIR lookup and convolution errors.
pub async fn process_realtime_chunk(
    &mut self,
    input: &Array1<f32>,
    left_output: &mut Array1<f32>,
    right_output: &mut Array1<f32>,
    position: &Position3D,
) -> crate::Result<()> {
    /// One overlap-add step for a single ear.
    ///
    /// Emits `conv[i] + carry[i]` for the first `chunk` samples. It then
    /// replaces `carry` with everything not yet emitted: the old carry
    /// shifted down by `chunk`, plus the new convolution tail. The carry
    /// grows when the tail would not fit.
    fn overlap_add(
        conv: &Array1<f32>,
        carry: &mut Array1<f32>,
        output: &mut Array1<f32>,
        chunk: usize,
    ) {
        for (i, out) in output.iter_mut().enumerate().take(chunk) {
            *out = conv[i] + carry.get(i).copied().unwrap_or(0.0);
        }
        let tail_len = conv.len() - chunk;
        let mut next = Array1::<f32>::zeros(carry.len().max(tail_len));
        for (i, slot) in next.iter_mut().enumerate() {
            let pending = carry.get(i + chunk).copied().unwrap_or(0.0);
            let fresh = conv.get(chunk + i).copied().unwrap_or(0.0);
            *slot = pending + fresh;
        }
        *carry = next;
    }

    let chunk_size = input.len();
    if left_output.len() != chunk_size || right_output.len() != chunk_size {
        return Err(crate::Error::processing(
            "Output buffer size must match input chunk size",
        ));
    }
    let spherical = self.cartesian_to_spherical(position);
    let (left_hrir, right_hrir) = self.get_hrtf(&spherical)?;
    // Size the convolution from the HRIRs actually returned. Stored HRIRs may
    // be longer than `config.hrir_length`, and sizing from the config would
    // cut off part of the tail.
    let hrir_len = left_hrir.len().max(right_hrir.len()).max(1);
    let conv_len = chunk_size + hrir_len - 1;
    let mut left_conv = Array1::zeros(conv_len);
    let mut right_conv = Array1::zeros(conv_len);
    self.convolve_hrtf(
        input,
        &left_hrir,
        &right_hrir,
        &mut left_conv,
        &mut right_conv,
    )?;
    // The previous code panicked when the chunk exceeded the overlap buffer.
    // It also dropped pending tail samples whenever the tail was longer than
    // one chunk.
    overlap_add(&left_conv, &mut self.overlap_left, left_output, chunk_size);
    overlap_add(&right_conv, &mut self.overlap_right, right_output, chunk_size);
    Ok(())
}
/// Clears both overlap-add tails so the next `process_realtime_chunk` call
/// starts from silence.
pub fn reset_buffers(&mut self) {
    for tail in [&mut self.overlap_left, &mut self.overlap_right] {
        tail.fill(0.0);
    }
}
/// Renders `input` at several positions and mixes the results.
///
/// Each `(position, weight)` pair is rendered with
/// [`Self::process_position`]. The rendering is scaled by `weight` and summed
/// into the outputs, which are cleared first. Weights are used as given and
/// are not normalised.
///
/// # Errors
/// Propagates the first error from [`Self::process_position`].
pub async fn process_crossfade(
    &self,
    input: &Array1<f32>,
    left_output: &mut Array1<f32>,
    right_output: &mut Array1<f32>,
    positions: &[(Position3D, f32)],
) -> crate::Result<()> {
    left_output.fill(0.0);
    right_output.fill(0.0);
    // Scratch buffers match each *output's* length. They used to be sized to
    // `input.len()` and then indexed up to `left_output.len()`, which panicked
    // whenever the output was longer than the input.
    let mut temp_left = Array1::<f32>::zeros(left_output.len());
    let mut temp_right = Array1::<f32>::zeros(right_output.len());
    for (position, weight) in positions {
        let weight = *weight;
        // Clear scratch so nothing from the previous position can leak
        // through samples this rendering does not overwrite.
        temp_left.fill(0.0);
        temp_right.fill(0.0);
        self.process_position(input, &mut temp_left, &mut temp_right, position)
            .await?;
        for (acc, &sample) in left_output.iter_mut().zip(temp_left.iter()) {
            *acc += sample * weight;
        }
        for (acc, &sample) in right_output.iter_mut().zip(temp_right.iter()) {
            *acc += sample * weight;
        }
    }
    Ok(())
}
fn get_hrtf(&self, coords: &SphericalCoordinates) -> crate::Result<(Array1<f32>, Array1<f32>)> {
match self.config.interpolation_method {
InterpolationMethod::Nearest => self.get_nearest_hrtf(coords),
InterpolationMethod::Bilinear => self.get_bilinear_hrtf(coords),
InterpolationMethod::Spherical => self.get_spherical_hrtf(coords),
InterpolationMethod::Weighted => self.get_weighted_hrtf(coords),
}
}
/// Returns a clone of the stored HRIR pair at the grid point closest to `coords`.
///
/// Azimuth closeness is measured around the circle (mod 360°); elevation
/// closeness is a plain difference.
///
/// # Errors
/// Returns `LegacyHrtf` when the chosen grid point has no stored pair.
fn get_nearest_hrtf(
    &self,
    coords: &SphericalCoordinates,
) -> crate::Result<(Array1<f32>, Array1<f32>)> {
    let azimuth = coords.azimuth.round() as i32;
    let elevation = coords.elevation.round() as i32;
    // Azimuth wraps at 360°: on a 0..355° grid, -170° must match 190°, not 0°.
    // Rank by wrapped distance first, then by plain distance, so a target of
    // 180° still prefers a stored 180° over an equally close -180°.
    let closest_azimuth = self
        .database
        .metadata
        .azimuth_angles
        .iter()
        .copied()
        .min_by_key(|&angle| {
            let diff = i64::from(angle) - i64::from(azimuth);
            let wrapped = diff.rem_euclid(360);
            (wrapped.min(360 - wrapped), diff.abs())
        })
        .unwrap_or(0);
    // Elevation does not wrap, so a plain nearest match is correct.
    let closest_elevation =
        self.find_closest_angle(elevation, &self.database.metadata.elevation_angles);
    let key = (closest_azimuth, closest_elevation);
    let missing = || {
        crate::Error::LegacyHrtf(format!(
            "No HRTF found for angles ({closest_azimuth}, {closest_elevation})"
        ))
    };
    let left_hrir = self.database.left_responses.get(&key).ok_or_else(missing)?;
    let right_hrir = self.database.right_responses.get(&key).ok_or_else(missing)?;
    Ok((left_hrir.clone(), right_hrir.clone()))
}
/// Bilinearly blends the four grid HRIR pairs that bracket `coords`.
///
/// Each coordinate is bracketed by the nearest grid angles at or below and
/// at or above it. The fractional position inside each bracket weights the
/// four corners.
///
/// # Errors
/// Returns `LegacyHrtf` when any of the four corners has no stored pair.
fn get_bilinear_hrtf(
    &self,
    coords: &SphericalCoordinates,
) -> crate::Result<(Array1<f32>, Array1<f32>)> {
    let azimuths = &self.database.metadata.azimuth_angles;
    let elevations = &self.database.metadata.elevation_angles;
    // Bracket with floor/ceil. The old `as i32` truncated toward zero, which
    // mis-bracketed negative angles (e.g. -0.5° snapped to the 0° grid point
    // instead of blending -15° and 0°).
    let az_low = self.find_lower_angle(coords.azimuth.floor() as i32, azimuths);
    let az_high = self.find_higher_angle(coords.azimuth.ceil() as i32, azimuths);
    let el_low = self.find_lower_angle(coords.elevation.floor() as i32, elevations);
    let el_high = self.find_higher_angle(coords.elevation.ceil() as i32, elevations);
    let hrtf_00 = self.get_hrtf_at_angles(az_low, el_low)?;
    let hrtf_01 = self.get_hrtf_at_angles(az_low, el_high)?;
    let hrtf_10 = self.get_hrtf_at_angles(az_high, el_low)?;
    let hrtf_11 = self.get_hrtf_at_angles(az_high, el_high)?;
    // Fractional position inside a bracket. It is clamped so that targets off
    // the grid never extrapolate; for those, the angle finders fall back to
    // an end of the list.
    let fraction = |value: f32, low: i32, high: i32| -> f32 {
        if high != low {
            ((value - low as f32) / (high - low) as f32).clamp(0.0, 1.0)
        } else {
            0.0
        }
    };
    let az_weight = fraction(coords.azimuth, az_low, az_high);
    let el_weight = fraction(coords.elevation, el_low, el_high);
    let w00 = (1.0 - az_weight) * (1.0 - el_weight);
    let w01 = (1.0 - az_weight) * el_weight;
    let w10 = az_weight * (1.0 - el_weight);
    let w11 = az_weight * el_weight;
    let left_hrir = self.interpolate_hrtf(&[
        (&hrtf_00.0, w00),
        (&hrtf_01.0, w01),
        (&hrtf_10.0, w10),
        (&hrtf_11.0, w11),
    ]);
    let right_hrir = self.interpolate_hrtf(&[
        (&hrtf_00.1, w00),
        (&hrtf_01.1, w01),
        (&hrtf_10.1, w10),
        (&hrtf_11.1, w11),
    ]);
    Ok((left_hrir, right_hrir))
}
/// Blends the (up to) four measurements nearest to `coords` by great-circle angle.
///
/// Weights are `1 / (sin(d) + 0.001)`, with `d` capped at 90° so the weight
/// never increases with distance. An exact grid hit (`d < 0.01°`) is used on
/// its own. The result has `config.hrir_length` samples, truncated or
/// zero-padded. Returns silence when the database is empty.
fn get_spherical_hrtf(
    &self,
    coords: &SphericalCoordinates,
) -> crate::Result<(Array1<f32>, Array1<f32>)> {
    let mut candidates: Vec<_> = self
        .database
        .left_responses
        .iter()
        .filter_map(|(&key, left_hrir)| {
            let right_hrir = self.database.right_responses.get(&key)?;
            let distance = self.calculate_angular_distance(
                coords.azimuth,
                coords.elevation,
                key.0 as f32,
                key.1 as f32,
            );
            Some((distance, key, left_hrir, right_hrir))
        })
        .collect();
    // HashMap iteration order is unspecified. Break distance ties by key so
    // the chosen neighbours, and therefore the output, are deterministic.
    candidates.sort_by(|a, b| {
        a.0.partial_cmp(&b.0)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.1.cmp(&b.1))
    });
    candidates.truncate(4);
    // An exact hit previously got weight 1.0, while neighbours ~15° away got
    // ~3.9, so the exact match was drowned out. Use it alone instead.
    if matches!(candidates.first(), Some(&(d, ..)) if d < 0.01) {
        candidates.truncate(1);
    }
    let mut left_sum = Array1::<f32>::zeros(self.config.hrir_length);
    let mut right_sum = Array1::<f32>::zeros(self.config.hrir_length);
    let mut weight_sum = 0.0_f32;
    for (distance, _, left_hrir, right_hrir) in candidates {
        let weight = if distance < 0.01 {
            1.0
        } else {
            1.0 / (distance.min(90.0).to_radians().sin() + 0.001)
        };
        weight_sum += weight;
        // Bound each ear by its own length; the right ear used to be indexed
        // with the left ear's bound.
        for (acc, &sample) in left_sum.iter_mut().zip(left_hrir.iter()) {
            *acc += sample * weight;
        }
        for (acc, &sample) in right_sum.iter_mut().zip(right_hrir.iter()) {
            *acc += sample * weight;
        }
    }
    if weight_sum > 0.0 {
        left_sum /= weight_sum;
        right_sum /= weight_sum;
    }
    Ok((left_sum, right_sum))
}
/// Blends every measurement within 30° of `coords`, weighted by
/// `1 / (1 + angular_distance)`.
///
/// The result has `config.hrir_length` samples, truncated or zero-padded.
/// When no measurement lies within 30° (a sparse grid), it falls back to
/// [`Self::get_nearest_hrtf`] instead of returning silence.
///
/// # Errors
/// Only the nearest-neighbour fallback can fail (`LegacyHrtf`).
fn get_weighted_hrtf(
    &self,
    coords: &SphericalCoordinates,
) -> crate::Result<(Array1<f32>, Array1<f32>)> {
    let mut left_sum = Array1::<f32>::zeros(self.config.hrir_length);
    let mut right_sum = Array1::<f32>::zeros(self.config.hrir_length);
    let mut weight_sum = 0.0_f32;
    for (&(az, el), left_hrir) in &self.database.left_responses {
        let Some(right_hrir) = self.database.right_responses.get(&(az, el)) else {
            continue;
        };
        let angular_distance = self.calculate_angular_distance(
            coords.azimuth,
            coords.elevation,
            az as f32,
            el as f32,
        );
        if angular_distance < 30.0 {
            let weight = 1.0 / (1.0 + angular_distance);
            weight_sum += weight;
            // `zip` stops at the shorter array. Stored HRIRs shorter than
            // `config.hrir_length` used to be indexed out of bounds (panic).
            for (acc, &sample) in left_sum.iter_mut().zip(left_hrir.iter()) {
                *acc += sample * weight;
            }
            for (acc, &sample) in right_sum.iter_mut().zip(right_hrir.iter()) {
                *acc += sample * weight;
            }
        }
    }
    if weight_sum > 0.0 {
        left_sum /= weight_sum;
        right_sum /= weight_sum;
        Ok((left_sum, right_sum))
    } else {
        self.get_nearest_hrtf(coords)
    }
}
/// Convolves `input` with both HRIRs, writing into the two outputs.
///
/// Direct (time-domain) convolution is used while either the input or the
/// left HRIR is shorter than 64 samples; otherwise FFT convolution is used.
fn convolve_hrtf(
    &self,
    input: &Array1<f32>,
    left_hrir: &Array1<f32>,
    right_hrir: &Array1<f32>,
    left_output: &mut Array1<f32>,
    right_output: &mut Array1<f32>,
) -> crate::Result<()> {
    const FFT_THRESHOLD: usize = 64;
    let use_fft = input.len() >= FFT_THRESHOLD && left_hrir.len() >= FFT_THRESHOLD;
    if use_fft {
        self.convolve_frequency_domain(input, left_hrir, right_hrir, left_output, right_output)
    } else {
        self.convolve_time_domain(input, left_hrir, right_hrir, left_output, right_output)
    }
}
/// Direct-form convolution, `y[n] = Σ x[i]·h[n - i]`, truncated to each
/// output's length. Both outputs are cleared first. Never fails.
fn convolve_time_domain(
    &self,
    input: &Array1<f32>,
    left_hrir: &Array1<f32>,
    right_hrir: &Array1<f32>,
    left_output: &mut Array1<f32>,
    right_output: &mut Array1<f32>,
) -> crate::Result<()> {
    let accumulate = |out: &mut Array1<f32>, kernel: &Array1<f32>| {
        out.fill(0.0);
        for (offset, &x) in input.iter().enumerate() {
            // `skip(offset).zip(kernel)` visits exactly the output slots that
            // `offset + j < out.len()` admits.
            for (y, &h) in out.iter_mut().skip(offset).zip(kernel.iter()) {
                *y += x * h;
            }
        }
    };
    accumulate(left_output, left_hrir);
    accumulate(right_output, right_hrir);
    Ok(())
}
fn convolve_frequency_domain(
&self,
input: &Array1<f32>,
left_hrir: &Array1<f32>,
right_hrir: &Array1<f32>,
left_output: &mut Array1<f32>,
right_output: &mut Array1<f32>,
) -> crate::Result<()> {
let input_len = input.len();
let hrir_len = left_hrir.len();
let conv_len = input_len + hrir_len - 1;
let fft_len = conv_len.next_power_of_two();
let input_complex: Vec<scirs2_core::Complex<f64>> = input
.iter()
.map(|&x| scirs2_core::Complex::new(x as f64, 0.0))
.chain(std::iter::repeat(scirs2_core::Complex::new(0.0, 0.0)))
.take(fft_len)
.collect();
let left_hrir_complex: Vec<scirs2_core::Complex<f64>> = left_hrir
.iter()
.map(|&x| scirs2_core::Complex::new(x as f64, 0.0))
.chain(std::iter::repeat(scirs2_core::Complex::new(0.0, 0.0)))
.take(fft_len)
.collect();
let right_hrir_complex: Vec<scirs2_core::Complex<f64>> = right_hrir
.iter()
.map(|&x| scirs2_core::Complex::new(x as f64, 0.0))
.chain(std::iter::repeat(scirs2_core::Complex::new(0.0, 0.0)))
.take(fft_len)
.collect();
let input_spectrum = scirs2_fft::fft(&input_complex, None)
.map_err(|e| crate::Error::LegacyProcessing(format!("FFT error: {e}")))?;
let left_hrir_spectrum = scirs2_fft::fft(&left_hrir_complex, None)
.map_err(|e| crate::Error::LegacyProcessing(format!("FFT error: {e}")))?;
let right_hrir_spectrum = scirs2_fft::fft(&right_hrir_complex, None)
.map_err(|e| crate::Error::LegacyProcessing(format!("FFT error: {e}")))?;
let left_result_spectrum: Vec<scirs2_core::Complex<f64>> = input_spectrum
.iter()
.zip(left_hrir_spectrum.iter())
.map(|(a, b)| a * b)
.collect();
let right_result_spectrum: Vec<scirs2_core::Complex<f64>> = input_spectrum
.iter()
.zip(right_hrir_spectrum.iter())
.map(|(a, b)| a * b)
.collect();
let left_result_time = scirs2_fft::ifft(&left_result_spectrum, None)
.map_err(|e| crate::Error::LegacyProcessing(format!("IFFT error: {e}")))?;
let right_result_time = scirs2_fft::ifft(&right_result_spectrum, None)
.map_err(|e| crate::Error::LegacyProcessing(format!("IFFT error: {e}")))?;
let output_len = left_output.len().min(conv_len);
for i in 0..output_len {
left_output[i] = left_result_time[i].re as f32;
right_output[i] = right_result_time[i].re as f32;
}
Ok(())
}
/// Converts a Cartesian position to azimuth/elevation (degrees) plus distance.
///
/// - Azimuth is `atan2(z, x)`.
/// - Elevation is `asin(y / distance)`.
/// - The origin, or any position whose distance is not positive, maps to
///   0°/0°.
fn cartesian_to_spherical(&self, position: &Position3D) -> SphericalCoordinates {
    let (x, y, z) = (position.x, position.y, position.z);
    let distance = (x * x + y * y + z * z).sqrt();
    if distance > 0.0 {
        SphericalCoordinates {
            azimuth: z.atan2(x).to_degrees(),
            elevation: (y / distance).asin().to_degrees(),
            distance,
        }
    } else {
        SphericalCoordinates {
            azimuth: 0.0,
            elevation: 0.0,
            distance,
        }
    }
}
/// Returns the entry of `available` with the smallest `|angle - target|`.
/// On ties the earliest entry wins. Returns 0 when `available` is empty.
fn find_closest_angle(&self, target: i32, available: &[i32]) -> i32 {
    let mut best: Option<(i32, i32)> = None; // (gap, angle)
    for &angle in available {
        let gap = (angle - target).abs();
        match best {
            Some((best_gap, _)) if best_gap <= gap => {}
            _ => best = Some((gap, angle)),
        }
    }
    best.map_or(0, |(_, angle)| angle)
}
/// Returns the largest entry of `available` that is `<= target`.
/// When none qualifies it returns the first entry, or 0 for an empty slice.
fn find_lower_angle(&self, target: i32, available: &[i32]) -> i32 {
    let mut lower: Option<i32> = None;
    for &angle in available {
        if angle <= target && lower.map_or(true, |best| angle > best) {
            lower = Some(angle);
        }
    }
    lower.unwrap_or_else(|| available.first().copied().unwrap_or(0))
}
/// Returns the smallest entry of `available` that is `>= target`.
/// When none qualifies it returns the last entry, or 0 for an empty slice.
fn find_higher_angle(&self, target: i32, available: &[i32]) -> i32 {
    let mut upper: Option<i32> = None;
    for &angle in available {
        if angle >= target && upper.map_or(true, |best| angle < best) {
            upper = Some(angle);
        }
    }
    upper.unwrap_or_else(|| available.last().copied().unwrap_or(0))
}
fn get_hrtf_at_angles(
&self,
azimuth: i32,
elevation: i32,
) -> crate::Result<(Array1<f32>, Array1<f32>)> {
let key = (azimuth, elevation);
let left = self.database.left_responses.get(&key).ok_or_else(|| {
crate::Error::LegacyHrtf(format!("No HRTF for angles ({azimuth}, {elevation})"))
})?;
let right = self.database.right_responses.get(&key).ok_or_else(|| {
crate::Error::LegacyHrtf(format!("No HRTF for angles ({azimuth}, {elevation})"))
})?;
Ok((left.clone(), right.clone()))
}
/// Weighted sum of HRIRs, truncated or zero-padded to `config.hrir_length`.
/// Weights are applied as given; they are not normalised here.
fn interpolate_hrtf(&self, weighted_hrirs: &[(&Array1<f32>, f32)]) -> Array1<f32> {
    weighted_hrirs.iter().fold(
        Array1::<f32>::zeros(self.config.hrir_length),
        |mut acc, &(hrir, weight)| {
            for (slot, &sample) in acc.iter_mut().zip(hrir.iter()) {
                *slot += sample * weight;
            }
            acc
        },
    )
}
/// Great-circle angle, in degrees, between two (azimuth, elevation)
/// directions given in degrees.
///
/// Uses the spherical law of cosines. The cosine is clamped so rounding
/// cannot push `acos` outside its domain.
fn calculate_angular_distance(&self, az1: f32, el1: f32, az2: f32, el2: f32) -> f32 {
    let lat_a = el1.to_radians();
    let lat_b = el2.to_radians();
    let delta_lon = az1.to_radians() - az2.to_radians();
    let cos_central = lat_a.sin() * lat_b.sin() + lat_a.cos() * lat_b.cos() * delta_lon.cos();
    cos_central.clamp(-1.0, 1.0).acos().to_degrees()
}
/// Scales both HRIRs by a distance-dependent gain and, in the near field,
/// adds an extra ear delay.
///
/// The gain is the inverse distance (distance clamped to at least 0.01). It
/// is boosted by `(near / d)^0.3`, capped at 3, inside the near field. Beyond
/// the far field it is reduced by `0.8 + 0.2 * far / d`.
fn apply_distance_modeling(
    &self,
    left_hrir: &mut Array1<f32>,
    right_hrir: &mut Array1<f32>,
    coords: &SphericalCoordinates,
) -> crate::Result<()> {
    let near = self.config.near_field_distance;
    let far = self.config.far_field_distance;
    let distance = coords.distance.max(0.01);
    let in_near_field = distance < near;

    let mut total_gain = 1.0 / distance;
    if in_near_field {
        total_gain *= (near / distance).powf(0.3).min(3.0);
    }
    if distance > far {
        total_gain *= 0.8 + 0.2 * (far / distance);
    }

    left_hrir.mapv_inplace(|sample| sample * total_gain);
    right_hrir.mapv_inplace(|sample| sample * total_gain);

    if in_near_field {
        self.apply_proximity_delay(left_hrir, right_hrir, coords)?;
    }
    Ok(())
}
/// Attenuates both HRIRs with a simple distance/temperature/humidity model.
///
/// The absorption factor is:
/// `exp(-d / 100) * (T_K / 293.15)^-0.1 * (1 - 0.1 * humidity)`.
/// It is clamped to `[0, 1]`, so absorption never amplifies and `powf` never
/// sees a negative base (which would produce NaN samples). Samples in the
/// second half of each HRIR get a steeper power of the factor.
fn apply_air_absorption(
    &self,
    left_hrir: &mut Array1<f32>,
    right_hrir: &mut Array1<f32>,
    distance: f32,
) -> crate::Result<()> {
    /// Applies `absorption` to the first half of `hrir` and
    /// `absorption^(1 + i/len)` to the second half.
    // NOTE(review): this shapes the HRIR along its *time* axis, not across
    // frequency, even though it looks intended as high-frequency damping.
    fn attenuate(hrir: &mut Array1<f32>, absorption: f32) {
        let len = hrir.len() as f32;
        for (i, sample) in hrir.iter_mut().enumerate() {
            let fraction = i as f32 / len;
            let gain = if fraction > 0.5 {
                absorption.powf(1.0 + fraction)
            } else {
                absorption
            };
            *sample *= gain;
        }
    }
    let temp_kelvin = self.config.temperature + 273.15;
    let temp_ratio = temp_kelvin / 293.15;
    let distance_factor = (-distance / 100.0).exp();
    let temp_factor = temp_ratio.powf(-0.1);
    // NOTE(review): this treats `humidity` as a 0..1 fraction. A percentage
    // (e.g. 50.0) would make the factor negative; the clamp below then mutes
    // the signal instead of producing NaN. Confirm the units.
    let humidity_factor = 1.0 - self.config.humidity * 0.1;
    let absorption_factor = (distance_factor * temp_factor * humidity_factor).clamp(0.0, 1.0);
    // Each ear uses its own length; the right ear used to be indexed with
    // the left ear's length.
    attenuate(left_hrir, absorption_factor);
    attenuate(right_hrir, absorption_factor);
    Ok(())
}
/// Adds an interaural delay to one ear for very close sources.
///
/// The delay applies only when the source is closer than two head radii and
/// more than 0.1 rad off-centre. The ITD is `r * |sin az| * (1 + cos az) / c`,
/// converted to samples and capped at a quarter of the HRIR. A positive
/// azimuth delays the left ear; otherwise the right ear is delayed.
fn apply_proximity_delay(
    &self,
    left_hrir: &mut Array1<f32>,
    right_hrir: &mut Array1<f32>,
    coords: &SphericalCoordinates,
) -> crate::Result<()> {
    const SPEED_OF_SOUND_M_PER_S: f32 = 343.0;
    let distance = coords.distance;
    let azimuth_rad = coords.azimuth.to_radians();
    // `head_circumference` (fallback 57.0) is treated as centimetres, like the
    // `SubjectInfo` values used in this file (57 cm / 2π ≈ 0.09 m). It is
    // converted to metres to match `distance` and the speed of sound. Using
    // the cm value as metres made the near-field test pass for almost any
    // distance and made the ITD ~100x too large.
    let head_radius_m =
        self.config.head_circumference.unwrap_or(57.0) / 100.0 / (2.0 * std::f32::consts::PI);
    let itd_samples = if distance < head_radius_m * 2.0 {
        // Use the magnitude; the sign of the azimuth picks the ear below. A
        // negative ITD used to saturate to 0 in the `as usize` cast, so
        // sources at negative azimuths were never delayed.
        let itd_seconds = head_radius_m * azimuth_rad.sin().abs() * (1.0 + azimuth_rad.cos())
            / SPEED_OF_SOUND_M_PER_S;
        (itd_seconds * self.config.sample_rate as f32) as usize
    } else {
        0
    };
    if itd_samples > 0 && azimuth_rad.abs() > 0.1 {
        if azimuth_rad > 0.0 {
            let delay_samples = itd_samples.min(left_hrir.len() / 4);
            self.apply_delay(left_hrir, delay_samples);
        } else {
            let delay_samples = itd_samples.min(right_hrir.len() / 4);
            self.apply_delay(right_hrir, delay_samples);
        }
    }
    Ok(())
}
/// Shifts `hrir` later by `delay_samples`, filling the lead-in with zeros.
///
/// Samples pushed past the end are dropped. It does nothing when the delay is
/// 0 or is at least the HRIR length.
fn apply_delay(&self, hrir: &mut Array1<f32>, delay_samples: usize) {
    let len = hrir.len();
    if delay_samples == 0 || delay_samples >= len {
        return;
    }
    // Shift in place, walking backwards so every source sample is read
    // before anything overwrites it.
    for i in (delay_samples..len).rev() {
        hrir[i] = hrir[i - delay_samples];
    }
    hrir.iter_mut().take(delay_samples).for_each(|s| *s = 0.0);
}
}
impl HrtfDatabase {
/// Loads an HRTF database, choosing the parser from the file extension
/// (case-insensitive):
/// - `.sofa` → [`Self::load_sofa_file`]
/// - `.json` → [`Self::load_json_file`]
/// - `.bin` or `.hrtf` → [`Self::load_binary_file`]
///
/// Any other or missing extension logs a warning and falls back to
/// [`Self::load_default`].
///
/// # Errors
/// Propagates the selected loader's error.
pub async fn load_from_file(path: &std::path::Path) -> crate::Result<Self> {
    let extension: Option<String> = path
        .extension()
        .and_then(std::ffi::OsStr::to_str)
        .map(str::to_lowercase);
    match extension.as_deref() {
        Some("sofa") => Self::load_sofa_file(path).await,
        Some("json") => Self::load_json_file(path).await,
        Some("bin" | "hrtf") => Self::load_binary_file(path).await,
        _ => {
            tracing::warn!("Unknown HRTF file format, using default database");
            Self::load_default().await
        }
    }
}
/// Parses a line-oriented, SOFA-like *text* file.
///
/// The file consists of `[SECTION]` headers:
/// - `[GLOBAL]` holds `key = value` metadata.
/// - `[POSITION]` holds `SourcePosition = az, el, dist` lines.
/// - `[DATA_IR]` holds `L: ...` / `R: ...` sample lines.
///
/// Blank lines and `#` comments are skipped. An L/R pair is stored under the
/// most recent position once both lines have been read. If nothing usable is
/// found, it falls back to [`Self::load_enhanced_default`].
// NOTE(review): real SOFA files are netCDF/HDF5 binaries; this parser only
// handles the text layout described above.
///
/// # Errors
/// Returns an HRTF error when the file cannot be read as UTF-8 text.
async fn load_sofa_file(path: &std::path::Path) -> crate::Result<Self> {
    tracing::info!("Loading SOFA HRTF file: {:?}", path);
    let content = tokio::fs::read_to_string(path)
        .await
        .map_err(|e| crate::Error::hrtf(&format!("Failed to read SOFA file: {e}")))?;
    // The text format carries no anthropometric data, so a fixed
    // placeholder subject is used.
    let placeholder_subject = || SubjectInfo {
        head_circumference: 56.0,
        head_width: 15.0,
        head_height: 20.0,
        ear_height: 10.0,
        shoulder_width: 40.0,
    };
    let mut metadata = HrtfMetadata {
        name: "SOFA HRTF Database".to_string(),
        sample_rate: 44100,
        hrir_length: 512,
        azimuth_angles: Vec::new(),
        elevation_angles: Vec::new(),
        distances: Some(vec![1.0]),
        subject_info: Some(placeholder_subject()),
    };
    let mut left_responses = HashMap::new();
    let mut right_responses = HashMap::new();
    let mut current_section = "";
    let mut current_measurement: Option<(i32, i32)> = None;
    let mut left_hrir_data: Vec<f32> = Vec::new();
    let mut right_hrir_data: Vec<f32> = Vec::new();
    for line in content.lines() {
        let line = line.trim();
        if line.is_empty() || line.starts_with('#') {
            continue;
        }
        if line.starts_with('[') && line.ends_with(']') {
            current_section = &line[1..line.len() - 1];
            continue;
        }
        match current_section {
            "GLOBAL" => {
                if let Some((key, value)) = line.split_once('=') {
                    let key = key.trim();
                    let value = value.trim();
                    match key {
                        "Data.SamplingRate" => {
                            if let Ok(rate) = value.parse::<u32>() {
                                metadata.sample_rate = rate;
                            }
                        }
                        "Data.IRLength" => {
                            if let Ok(length) = value.parse::<usize>() {
                                metadata.hrir_length = length;
                            }
                        }
                        "GLOBAL:DatabaseName" => {
                            metadata.name = value.to_string();
                        }
                        "GLOBAL:ListenerShortName" => {
                            metadata.subject_info = Some(placeholder_subject());
                        }
                        _ => {}
                    }
                }
            }
            "POSITION" if line.starts_with("SourcePosition") => {
                if let Some((_, coords)) = line.split_once('=') {
                    let parts: Vec<&str> = coords.split(',').collect();
                    if parts.len() >= 3 {
                        if let (Ok(azimuth), Ok(elevation), Ok(distance)) = (
                            parts[0].trim().parse::<f32>(),
                            parts[1].trim().parse::<f32>(),
                            parts[2].trim().parse::<f32>(),
                        ) {
                            // Round rather than truncate, so 89.9999 maps to
                            // 90 and not 89.
                            current_measurement =
                                Some((azimuth.round() as i32, elevation.round() as i32));
                            metadata.distances = Some(vec![distance]);
                            // A new position starts a new measurement. Drop any
                            // half-read L/R data so channels from different
                            // positions can never be paired.
                            left_hrir_data.clear();
                            right_hrir_data.clear();
                        }
                    }
                }
            }
            "DATA_IR" => {
                if let Some((azimuth, elevation)) = current_measurement {
                    if let Some(data_str) = line.strip_prefix("L:") {
                        left_hrir_data = data_str
                            .split_whitespace()
                            .filter_map(|s| s.parse::<f32>().ok())
                            .collect();
                    } else if let Some(data_str) = line.strip_prefix("R:") {
                        right_hrir_data = data_str
                            .split_whitespace()
                            .filter_map(|s| s.parse::<f32>().ok())
                            .collect();
                    }
                    if !left_hrir_data.is_empty() && !right_hrir_data.is_empty() {
                        // `take` moves the samples out and leaves the buffers
                        // empty, ready for the next measurement.
                        left_responses.insert(
                            (azimuth, elevation),
                            Array1::from(std::mem::take(&mut left_hrir_data)),
                        );
                        right_responses.insert(
                            (azimuth, elevation),
                            Array1::from(std::mem::take(&mut right_hrir_data)),
                        );
                        if !metadata.azimuth_angles.contains(&azimuth) {
                            metadata.azimuth_angles.push(azimuth);
                        }
                        if !metadata.elevation_angles.contains(&elevation) {
                            metadata.elevation_angles.push(elevation);
                        }
                        current_measurement = None;
                    }
                }
            }
            _ => {}
        }
    }
    metadata.azimuth_angles.sort();
    metadata.elevation_angles.sort();
    if left_responses.is_empty() || right_responses.is_empty() {
        tracing::warn!("No valid HRTF measurements found in SOFA file, using enhanced default");
        return Self::load_enhanced_default().await;
    }
    tracing::info!(
        "Successfully loaded {} HRTF measurements from SOFA file",
        left_responses.len()
    );
    Ok(Self {
        metadata,
        left_responses,
        right_responses,
        frequency_responses: None,
        distance_responses: None,
    })
}
/// Loads HRIRs from a JSON document with `metadata` and `measurements` keys.
///
/// Each measurement provides `azimuth`, `elevation`, `left_hrir` and
/// `right_hrir`. Angles may be integers or floats and are rounded to whole
/// degrees. Only measurements with both channels present are kept. The
/// angle grid is derived from the loaded measurements, like the SOFA and
/// binary loaders do.
///
/// A missing `metadata` object, or no usable measurements, falls back to
/// [`Self::load_enhanced_default`].
///
/// # Errors
/// Returns an HRTF error when the file cannot be read or is not valid JSON.
async fn load_json_file(path: &std::path::Path) -> crate::Result<Self> {
    tracing::info!("Loading JSON HRTF file: {:?}", path);
    let content = tokio::fs::read_to_string(path)
        .await
        .map_err(|e| crate::Error::hrtf(&format!("Failed to read JSON file: {e}")))?;
    let json_data: serde_json::Value = serde_json::from_str(&content)
        .map_err(|e| crate::Error::hrtf(&format!("Failed to parse JSON: {e}")))?;
    let Some(meta) = json_data.get("metadata") else {
        return Self::load_enhanced_default().await;
    };
    // Round onto the whole-degree grid used for map keys. `as_i64` alone
    // rejected values like 7.5, and the measurement then landed at 0°.
    let angle_of = |value: &serde_json::Value| value.as_f64().map(|deg| deg.round() as i32);
    // Non-numeric samples are skipped; an empty channel counts as missing.
    let samples_of = |value: &serde_json::Value| -> Option<Vec<f32>> {
        let samples: Vec<f32> = value
            .as_array()?
            .iter()
            .filter_map(|x| x.as_f64().map(|f| f as f32))
            .collect();
        (!samples.is_empty()).then_some(samples)
    };
    let mut left_responses = HashMap::new();
    let mut right_responses = HashMap::new();
    let mut azimuth_angles: Vec<i32> = Vec::new();
    let mut elevation_angles: Vec<i32> = Vec::new();
    if let Some(measurements) = json_data.get("measurements").and_then(|v| v.as_array()) {
        for measurement in measurements {
            let azimuth = measurement.get("azimuth").and_then(angle_of).unwrap_or(0);
            let elevation = measurement.get("elevation").and_then(angle_of).unwrap_or(0);
            let left = measurement.get("left_hrir").and_then(samples_of);
            let right = measurement.get("right_hrir").and_then(samples_of);
            // Store only complete pairs so both maps share exactly the same
            // keys; processor lookups need the key in both.
            if let (Some(left), Some(right)) = (left, right) {
                left_responses.insert((azimuth, elevation), Array1::from(left));
                right_responses.insert((azimuth, elevation), Array1::from(right));
                azimuth_angles.push(azimuth);
                elevation_angles.push(elevation);
            }
        }
    }
    if left_responses.is_empty() {
        tracing::warn!("No valid HRTF measurements found in JSON file, using enhanced default");
        return Self::load_enhanced_default().await;
    }
    // Derive the grid from the data actually loaded. A metadata-declared grid
    // could list angles with no measurements, making nearest/bilinear lookups
    // fail.
    azimuth_angles.sort_unstable();
    azimuth_angles.dedup();
    elevation_angles.sort_unstable();
    elevation_angles.dedup();
    let metadata = HrtfMetadata {
        name: meta
            .get("name")
            .and_then(|v| v.as_str())
            .unwrap_or("JSON HRTF Database")
            .to_string(),
        sample_rate: meta
            .get("sample_rate")
            .and_then(|v| v.as_u64())
            .unwrap_or(44100) as u32,
        hrir_length: meta
            .get("hrir_length")
            .and_then(|v| v.as_u64())
            .unwrap_or(512) as usize,
        azimuth_angles,
        elevation_angles,
        distances: meta
            .get("distance")
            .and_then(|v| v.as_f64())
            .map(|d| vec![d as f32]),
        subject_info: meta.get("subject_id").and_then(|v| v.as_str()).map(|_| {
            SubjectInfo {
                head_circumference: 56.0,
                head_width: 15.0,
                head_height: 20.0,
                ear_height: 10.0,
                shoulder_width: 40.0,
            }
        }),
    };
    tracing::info!(
        "Successfully loaded {} HRTF measurements from JSON",
        left_responses.len()
    );
    Ok(Self {
        metadata,
        left_responses,
        right_responses,
        frequency_responses: None,
        distance_responses: None,
    })
}
async fn load_binary_file(path: &std::path::Path) -> crate::Result<Self> {
tracing::info!("Loading binary HRTF file: {:?}", path);
let data = tokio::fs::read(path)
.await
.map_err(|e| crate::Error::hrtf(&format!("Failed to read binary file: {e}")))?;
if data.len() < 32 {
return Err(crate::Error::hrtf(
"Binary file too small to contain valid HRTF data",
));
}
let mut cursor = 0;
let magic = &data[cursor..cursor + 4];
if magic != b"HRTF" {
return Err(crate::Error::hrtf("Invalid binary HRTF file format"));
}
cursor += 4;
let version = u32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
cursor += 4;
if version != 1 {
return Err(crate::Error::hrtf(&format!(
"Unsupported HRTF binary version: {version}"
)));
}
let sample_rate = u32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
cursor += 4;
let hrir_length = u32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]) as usize;
cursor += 4;
let measurement_count = u32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]) as usize;
cursor += 4;
let distance = f32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
cursor += 4;
let subject_id_len = u32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]) as usize;
cursor += 4;
if cursor + subject_id_len > data.len() {
return Err(crate::Error::hrtf(
"Invalid subject ID length in binary file",
));
}
let subject_id =
String::from_utf8_lossy(&data[cursor..cursor + subject_id_len]).to_string();
cursor += subject_id_len;
let mut left_responses = HashMap::new();
let mut right_responses = HashMap::new();
let mut azimuth_angles = Vec::new();
let mut elevation_angles = Vec::new();
for _ in 0..measurement_count {
if cursor + 8 + (hrir_length * 8) > data.len() {
return Err(crate::Error::hrtf("Insufficient data for measurement"));
}
let azimuth = i32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
cursor += 4;
let elevation = i32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
cursor += 4;
let mut left_hrir = Vec::with_capacity(hrir_length);
for _ in 0..hrir_length {
let sample = f32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
left_hrir.push(sample);
cursor += 4;
}
let mut right_hrir = Vec::with_capacity(hrir_length);
for _ in 0..hrir_length {
let sample = f32::from_le_bytes([
data[cursor],
data[cursor + 1],
data[cursor + 2],
data[cursor + 3],
]);
right_hrir.push(sample);
cursor += 4;
}
left_responses.insert((azimuth, elevation), Array1::from(left_hrir));
right_responses.insert((azimuth, elevation), Array1::from(right_hrir));
if !azimuth_angles.contains(&azimuth) {
azimuth_angles.push(azimuth);
}
if !elevation_angles.contains(&elevation) {
elevation_angles.push(elevation);
}
}
azimuth_angles.sort();
elevation_angles.sort();
if left_responses.is_empty() || right_responses.is_empty() {
return Err(crate::Error::hrtf(
"No valid HRTF measurements found in binary file",
));
}
let metadata = HrtfMetadata {
name: format!("Binary HRTF Database ({subject_id})"),
sample_rate,
hrir_length,
azimuth_angles,
elevation_angles,
distances: Some(vec![distance]),
subject_info: Some(SubjectInfo {
head_circumference: 56.0,
head_width: 15.0,
head_height: 20.0,
ear_height: 10.0,
shoulder_width: 40.0,
}),
};
tracing::info!(
"Successfully loaded {} HRTF measurements from binary file",
left_responses.len()
);
Ok(Self {
metadata,
left_responses,
right_responses,
frequency_responses: None,
distance_responses: None,
})
}
/// Builds a synthetic database on a 5° grid from `generate_enhanced_hrtf`.
///
/// The grid covers azimuth -180..=180 and elevation -90..=90. Responses are
/// generated at distances 0.2, 0.5, 1, 2 and 5:
/// - the angle-only maps hold the ~1.0 responses;
/// - every distance is stored in `distance_responses` under
///   `(distance * 100) as u32`.
///
/// Never fails.
async fn load_enhanced_default() -> crate::Result<Self> {
    let metadata = HrtfMetadata {
        name: "Enhanced Default HRTF".to_string(),
        sample_rate: 44100,
        hrir_length: 512,
        azimuth_angles: (-180..=180).step_by(5).collect(),
        elevation_angles: (-90..=90).step_by(5).collect(),
        distances: Some(vec![0.2, 0.5, 1.0, 2.0, 5.0]),
        subject_info: Some(SubjectInfo {
            head_circumference: 57.0,
            head_width: 15.5,
            head_height: 24.0,
            ear_height: 12.0,
            shoulder_width: 45.0,
        }),
    };
    // Read once, outside the loops; the literal above always provides it.
    let distances = metadata
        .distances
        .as_deref()
        .expect("distances must be provided in enhanced HRTF metadata");
    let hrir_length = metadata.hrir_length;
    let mut left_responses = HashMap::new();
    let mut right_responses = HashMap::new();
    let mut distance_responses = HashMap::new();
    for &azimuth in &metadata.azimuth_angles {
        for &elevation in &metadata.elevation_angles {
            for &distance in distances {
                let (left_hrir, right_hrir) =
                    Self::generate_enhanced_hrtf(azimuth, elevation, distance, hrir_length);
                let is_reference_distance = (distance - 1.0).abs() < 0.1;
                if is_reference_distance {
                    left_responses.insert((azimuth, elevation), left_hrir.clone());
                    right_responses.insert((azimuth, elevation), right_hrir.clone());
                }
                let distance_key = (distance * 100.0) as u32;
                distance_responses
                    .insert((azimuth, elevation, distance_key), (left_hrir, right_hrir));
            }
        }
    }
    Ok(Self {
        metadata,
        left_responses,
        right_responses,
        frequency_responses: None,
        distance_responses: Some(distance_responses),
    })
}
/// Builds the minimal built-in database.
///
/// It uses 256-sample impulse HRIRs from `generate_simple_hrtf` on a 15°
/// grid: azimuth -180..=180, elevation -90..=90. No distances or subject
/// info are recorded.
///
/// # Errors
/// Never fails; the `Result` matches the other loaders.
pub async fn load_default() -> crate::Result<Self> {
    const HRIR_LENGTH: usize = 256;
    let azimuth_angles: Vec<i32> = (-180..=180).step_by(15).collect();
    let elevation_angles: Vec<i32> = (-90..=90).step_by(15).collect();
    let grid_points = azimuth_angles.len() * elevation_angles.len();
    let mut left_responses = HashMap::with_capacity(grid_points);
    let mut right_responses = HashMap::with_capacity(grid_points);
    for &azimuth in &azimuth_angles {
        for &elevation in &elevation_angles {
            let (left, right) = Self::generate_simple_hrtf(azimuth, elevation, HRIR_LENGTH);
            left_responses.insert((azimuth, elevation), left);
            right_responses.insert((azimuth, elevation), right);
        }
    }
    let metadata = HrtfMetadata {
        name: "Default HRTF".to_string(),
        sample_rate: 44100,
        hrir_length: HRIR_LENGTH,
        azimuth_angles,
        elevation_angles,
        distances: None,
        subject_info: None,
    };
    Ok(Self {
        metadata,
        left_responses,
        right_responses,
        frequency_responses: None,
        distance_responses: None,
    })
}
/// Builds a pair of single-impulse HRIRs of `length` samples for `azimuth`.
///
/// - The impulse is delayed by `|azimuth| / 180 * 10` samples, truncated.
///   Positive azimuths delay the left channel; negative ones delay the
///   right. At 0° both start at sample 0.
/// - Both channels share the gain `1 - |azimuth| / 180 * 0.3`.
/// - An impulse whose delay falls outside the buffer is dropped.
/// - Elevation is ignored.
fn generate_simple_hrtf(
    azimuth: i32,
    _elevation: i32,
    length: usize,
) -> (Array1<f32>, Array1<f32>) {
    let delay_for = |degrees: i32| (degrees as f32 / 180.0 * 10.0) as usize;
    let left_delay = if azimuth < 0 { 0 } else { delay_for(azimuth) };
    let right_delay = if azimuth > 0 { 0 } else { delay_for(-azimuth) };
    let gain = 1.0 - (azimuth as f32).abs() / 180.0 * 0.3;
    let impulse_at = |delay: usize| {
        let mut hrir = Array1::<f32>::zeros(length);
        if let Some(sample) = hrir.get_mut(delay) {
            *sample = gain;
        }
        hrir
    };
    (impulse_at(left_delay), impulse_at(right_delay))
}
/// Synthesizes a parametric left/right HRIR pair for one grid point of the enhanced database.
///
/// Per ear, the model writes:
/// - A direct impulse. Its onset is `distance / 343 * 44100` samples (343 is presumably the
///   speed of sound in m/s and `distance` in metres — confirm with callers). To this it adds
///   an interaural delay of half a Woodworth-style ITD, `r / 343 * (θ + sin θ) * 44100` with
///   `r = 0.09`. The interaural delay goes to the left ear when `azimuth >= 0` and to the
///   right ear when `azimuth < 0`.
/// - A single reflection 88 samples (2 ms at 44.1 kHz) after the direct onset, with gain
///   `0.1 / (distance + 0.01)`.
///
/// Direct-path gain is `1 / (distance + 0.01)`, multiplied by:
/// - an elevation term `1 + 0.2 * sin(elevation)`;
/// - a near-field boost for `distance < 0.5`;
/// - a 30% attenuation of the delayed ear when |azimuth| > 90°.
///
/// A raised-cosine fade-in is then applied to the first `min(length / 8, 32)` samples. Taps
/// that would land at or beyond `length` are dropped.
fn generate_enhanced_hrtf(
    azimuth: i32,
    elevation: i32,
    distance: f32,
    length: usize,
) -> (Array1<f32>, Array1<f32>) {
    // NOTE(review): fixed at 44.1 kHz regardless of the owning database's sample rate.
    let sample_rate: f32 = 44100.0;
    let speed_of_sound: f32 = 343.0;
    let head_radius: f32 = 0.09;
    let half_pi = std::f32::consts::PI / 2.0;

    let mut left_hrir: Array1<f32> = Array1::zeros(length);
    let mut right_hrir: Array1<f32> = Array1::zeros(length);
    let azimuth_rad = (azimuth as f32).to_radians();
    let elevation_rad = (elevation as f32).to_radians();
    let distance_attenuation = 1.0 / (distance + 0.01);

    // Compute the ITD magnitude from |azimuth| and apply its sign by choosing the delayed ear.
    // Using the signed angle in the rear branch (`π/2 + sin θ`) made the ITD positive for
    // azimuths in (-180°, -90°). Negating it then saturated the right-ear delay to 0.
    let lateral_rad = azimuth_rad.abs();
    let itd = if lateral_rad <= half_pi {
        (head_radius / speed_of_sound) * (lateral_rad + lateral_rad.sin()) * sample_rate
    } else {
        (head_radius / speed_of_sound) * (half_pi + lateral_rad.sin()) * sample_rate
    };
    let interaural_delay = (itd / 2.0) as usize;
    let (left_delay, right_delay) = if azimuth >= 0 {
        (interaural_delay, 0)
    } else {
        (0, interaural_delay)
    };

    // Head shadow is only modelled behind the interaural axis (|azimuth| > 90°). It attenuates
    // the same ear that receives the interaural delay.
    let shadow_attenuation = if lateral_rad > half_pi { 0.3 } else { 0.0 };
    let left_gain =
        distance_attenuation * (1.0 - if azimuth > 0 { shadow_attenuation } else { 0.0 });
    let right_gain =
        distance_attenuation * (1.0 - if azimuth < 0 { shadow_attenuation } else { 0.0 });
    let elevation_gain = (1.0 + 0.2 * elevation_rad.sin()).clamp(0.5, 1.5);
    let near_field_boost = if distance < 0.5 {
        1.0 + (0.5 - distance) * 0.5
    } else {
        1.0
    };

    // Direct sound. `saturating_add` keeps absurd distances (whose `as usize` cast saturates)
    // from overflowing; such taps simply fall outside the buffer.
    let primary_delay = (distance / speed_of_sound * sample_rate) as usize;
    let left_direct = primary_delay.saturating_add(left_delay);
    if left_direct < length {
        left_hrir[left_direct] = left_gain * elevation_gain * near_field_boost;
    }
    let right_direct = primary_delay.saturating_add(right_delay);
    if right_direct < length {
        right_hrir[right_direct] = right_gain * elevation_gain * near_field_boost;
    }

    // Single reflection 2 ms after the direct onset, equal gain on both ears.
    let reflection_delay = primary_delay.saturating_add((0.002 * sample_rate) as usize);
    let reflection_gain = 0.1 * distance_attenuation;
    let left_reflection = reflection_delay.saturating_add(left_delay);
    if left_reflection < length {
        left_hrir[left_reflection] += reflection_gain;
    }
    let right_reflection = reflection_delay.saturating_add(right_delay);
    if right_reflection < length {
        right_hrir[right_reflection] += reflection_gain;
    }

    // Raised-cosine fade-in. `window_size <= length / 8`, so every index is in bounds.
    // NOTE(review): this also scales the direct impulse when its onset falls inside the window
    // (with a 32-sample window at 44.1 kHz, distances below about 0.25). Confirm this is intended.
    let window_size = (length / 8).min(32);
    for i in 0..window_size {
        let window_val =
            0.5 * (1.0 - ((i as f32) / (window_size as f32) * std::f32::consts::PI).cos());
        left_hrir[i] *= window_val;
        right_hrir[i] *= window_val;
    }
    (left_hrir, right_hrir)
}
pub fn metadata(&self) -> &HrtfMetadata {
&self.metadata
}
/// Lists every (azimuth, elevation) grid point declared in the metadata.
///
/// Points come azimuth-major: for each azimuth, all elevations in metadata order. Angles are
/// converted to `f32`, and every entry reports a distance of 1.0.
pub fn available_positions(&self) -> Vec<SphericalCoordinates> {
    let elevations = &self.metadata.elevation_angles;
    self.metadata
        .azimuth_angles
        .iter()
        .flat_map(move |&az| {
            elevations.iter().map(move |&el| SphericalCoordinates {
                azimuth: az as f32,
                elevation: el as f32,
                distance: 1.0,
            })
        })
        .collect()
}
}
impl Default for HrtfConfig {
fn default() -> Self {
Self {
sample_rate: 44100,
hrir_length: 256,
crossfade_time: 0.01, enable_distance_modeling: true,
interpolation_method: InterpolationMethod::Bilinear,
head_circumference: None,
near_field_distance: 0.2, far_field_distance: 10.0, enable_air_absorption: true,
temperature: 20.0, humidity: 0.5, enable_simd: true,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Allocates a zeroed (left, right) pair of output buffers, `len` samples each.
    fn silent_stereo(len: usize) -> (Array1<f32>, Array1<f32>) {
        (Array1::zeros(len), Array1::zeros(len))
    }

    #[tokio::test]
    async fn test_hrtf_processor_creation() {
        let created = HrtfProcessor::new(None).await;
        assert!(created.is_ok());
    }

    #[tokio::test]
    async fn test_hrtf_database_loading() {
        let loaded = HrtfDatabase::load_default().await;
        assert!(loaded.is_ok());
        let db = loaded.unwrap();
        let left_count = db.left_responses.len();
        let right_count = db.right_responses.len();
        assert!(left_count > 0);
        assert!(right_count > 0);
        assert_eq!(left_count, right_count);
    }

    #[tokio::test]
    async fn test_cartesian_to_spherical() {
        let processor = HrtfProcessor::new(None).await.unwrap();

        let on_x_axis = processor.cartesian_to_spherical(&Position3D::new(1.0, 0.0, 0.0));
        assert!(on_x_axis.azimuth.abs() < 0.1);
        assert!(on_x_axis.elevation.abs() < 0.1);

        let on_z_axis = processor.cartesian_to_spherical(&Position3D::new(0.0, 0.0, 1.0));
        assert!((on_z_axis.azimuth - 90.0).abs() < 0.1);
    }

    #[tokio::test]
    async fn test_hrtf_processing() {
        let processor = HrtfProcessor::new(None).await.unwrap();
        let input = Array1::from_vec(vec![1.0, 0.5, -0.5, -1.0]);
        let (mut left_output, mut right_output) = silent_stereo(input.len());
        let source = Position3D::new(1.0, 0.0, 0.0);
        let outcome = processor
            .process_position(&input, &mut left_output, &mut right_output, &source)
            .await;
        assert!(outcome.is_ok());
    }

    #[tokio::test]
    async fn test_realtime_processing() {
        let mut processor = HrtfProcessor::new(None).await.unwrap();
        let chunk_size = 64;
        let input = Array1::from_vec(vec![0.1; chunk_size]);
        let (mut left_output, mut right_output) = silent_stereo(chunk_size);
        let source = Position3D::new(1.0, 0.0, 0.0);
        // Feed the same chunk twice in a row; both calls must succeed.
        for _ in 0..2 {
            let outcome = processor
                .process_realtime_chunk(&input, &mut left_output, &mut right_output, &source)
                .await;
            assert!(outcome.is_ok());
        }
    }

    #[tokio::test]
    async fn test_interpolation_methods() {
        let cases = [
            (InterpolationMethod::Nearest, "Nearest"),
            (InterpolationMethod::Bilinear, "Bilinear"),
            (InterpolationMethod::Spherical, "Spherical"),
            (InterpolationMethod::Weighted, "Weighted"),
        ];
        let target = SphericalCoordinates {
            azimuth: 45.0,
            elevation: 15.0,
            distance: 1.0,
        };
        for (method, name) in cases {
            let config = HrtfConfig {
                interpolation_method: method,
                ..Default::default()
            };
            let processor = HrtfProcessor::with_config(None, config).await.unwrap();
            let lookup = processor.get_hrtf(&target);
            assert!(lookup.is_ok(), "Failed interpolation method: {}", name);
        }
    }

    #[tokio::test]
    async fn test_crossfade_processing() {
        let processor = HrtfProcessor::new(None).await.unwrap();
        let input = Array1::from_vec(vec![1.0, 0.5, -0.5, -1.0]);
        let (mut left_output, mut right_output) = silent_stereo(input.len());
        let weighted_sources = vec![
            (Position3D::new(1.0, 0.0, 0.0), 0.7),
            (Position3D::new(-1.0, 0.0, 0.0), 0.3),
        ];
        let outcome = processor
            .process_crossfade(&input, &mut left_output, &mut right_output, &weighted_sources)
            .await;
        assert!(outcome.is_ok());

        let abs_sum = |buf: &Array1<f32>| -> f32 { buf.iter().map(|x| x.abs()).sum() };
        assert!(abs_sum(&left_output) > 0.0);
        assert!(abs_sum(&right_output) > 0.0);
    }

    #[tokio::test]
    async fn test_frequency_domain_convolution() {
        let processor = HrtfProcessor::new(None).await.unwrap();
        let taps = processor.config.hrir_length;
        let input = Array1::from_vec(vec![0.1; 128]);
        let left_hrir = Array1::from_vec(vec![1.0; taps]);
        let right_hrir = Array1::from_vec(vec![0.8; taps]);
        let (mut left_output, mut right_output) = silent_stereo(input.len());
        let outcome = processor.convolve_hrtf(
            &input,
            &left_hrir,
            &right_hrir,
            &mut left_output,
            &mut right_output,
        );
        assert!(outcome.is_ok());

        let energy = |buf: &Array1<f32>| -> f32 { buf.iter().map(|x| x * x).sum() };
        assert!(energy(&left_output) > 0.0);
        assert!(energy(&right_output) > 0.0);
    }
}