use crate::types::Position3D;
use crate::{Error, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::{Arc, RwLock};
/// Front end for HRTF lookups.
///
/// Holds the main HRTF database and a per-user cache of personalized HRTFs, each behind
/// an `Arc<RwLock<_>>`, together with the lookup configuration and usage counters.
/// `metrics` is a plain field rather than being behind a lock, so recording a lookup
/// (see `get_hrtf`) requires `&mut self`.
pub struct HrtfDatabaseManager {
    /// Main (generic) HRTF database; replaced wholesale by `load_database`.
    main_database: Arc<RwLock<HrtfDatabase>>,
    /// Personalized HRTFs keyed by user id (the key passed to `create_personalized_hrtf`).
    personalized_cache: Arc<RwLock<HashMap<String, PersonalizedHrtf>>>,
    /// Lookup / loading / optimization settings.
    config: DatabaseConfig,
    /// Lookup counters and timing average.
    metrics: DatabaseMetrics,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HrtfDatabase {
pub metadata: DatabaseMetadata,
pub measurements: HashMap<HrtfPosition, HrtfMeasurement>,
interpolation_cache: HashMap<HrtfPosition, Vec<InterpolationWeight>>,
distance_hrtfs: HashMap<u32, HashMap<HrtfPosition, HrtfMeasurement>>,
}
/// Descriptive information about an HRTF measurement set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseMetadata {
    /// Human-readable database name.
    pub name: String,
    /// Version string of the database (`HrtfDatabase::new` uses `"1.0.0"`).
    pub version: String,
    /// Sample rate of the impulse responses (48000 in the default database; presumably Hz).
    pub sample_rate: u32,
    /// Nominal impulse-response length (512 in the default database; presumably samples).
    /// Nothing here checks it against the actual `left_ir`/`right_ir` lengths.
    pub hrtf_length: usize,
    /// Room, equipment and quality information for the measurement session.
    pub conditions: MeasurementConditions,
    /// Information about the measured subject.
    pub subject_demographics: SubjectDemographics,
    /// Creation timestamp, stored as a string (the default is an RFC 3339-style literal).
    pub created: String,
    /// Last-modification timestamp, stored as a string.
    pub modified: String,
}
/// Conditions under which the HRTFs in a database were recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MeasurementConditions {
    /// Acoustic properties of the measurement room.
    pub room: RoomCharacteristics,
    /// Hardware and software used for the recording.
    pub equipment: EquipmentInfo,
    /// Free-form description of the measurement procedure.
    pub methodology: String,
    /// Quality figures reported for the measurement session.
    pub quality_metrics: QualityMetrics,
}
/// Acoustic description of the measurement room.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RoomCharacteristics {
    /// Free-form room description (the default database uses `"anechoic"`).
    pub room_type: String,
    /// Room dimensions as a 3-tuple (default `(5.0, 4.0, 3.0)`); presumably metres.
    /// TODO confirm the unit and axis order.
    pub dimensions: (f32, f32, f32),
    /// RT60 reverberation time (default `0.05`); presumably seconds.
    pub rt60: f32,
    /// Background noise level (default `-40.0`); presumably dB. TODO confirm the reference.
    pub noise_floor: f32,
}
/// Free-form identification of the recording chain.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EquipmentInfo {
    /// Microphone model.
    pub microphone: String,
    /// Loudspeaker model.
    pub speaker: String,
    /// Audio interface model.
    pub audio_interface: String,
    /// Software used for capture/processing.
    pub software: String,
}
/// Quality figures for a measurement session.
///
/// The units are not established by the visible code. The default database uses
/// `snr: 60.0`, `thd: 0.1`, `frequency_deviation: 1.0` and `phase_coherence: 0.95`.
/// TODO confirm whether these are dB, percent or a 0..1 ratio.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Signal-to-noise ratio.
    pub snr: f32,
    /// Total harmonic distortion.
    pub thd: f32,
    /// Deviation of the frequency response.
    pub frequency_deviation: f32,
    /// Phase coherence figure.
    pub phase_coherence: f32,
}
/// Information about the subject (or dummy head) whose HRTFs were measured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubjectDemographics {
    /// Age of the subject (default `25`; presumably years).
    pub age: u32,
    /// Free-form gender description (the default database uses `"Mixed"`).
    pub gender: String,
    /// Anthropometric head and pinna measurements.
    pub head_measurements: HeadMeasurements,
    /// Hearing status of the subject.
    pub hearing_assessment: HearingAssessment,
}
/// Head anthropometry.
///
/// Also used as the input to `HrtfDatabaseManager::create_personalized_hrtf`. The defaults
/// (width 15.5, depth 19.0, circumference 56.0, interaural distance 14.5) suggest
/// centimetres. TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeadMeasurements {
    /// Head width.
    pub width: f32,
    /// Head depth.
    pub depth: f32,
    /// Head circumference.
    pub circumference: f32,
    /// Distance between the ears.
    pub interaural_distance: f32,
    /// Left pinna dimensions.
    pub pinna_left: PinnaMeasurements,
    /// Right pinna dimensions.
    pub pinna_right: PinnaMeasurements,
}
/// Outer-ear dimensions.
///
/// The units are not fixed by the visible code (defaults: height 6.2, width 3.5,
/// depth 2.1, concha depth 1.2, concha volume 2.8). Presumably cm and cm³; TODO confirm.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PinnaMeasurements {
    /// Pinna height.
    pub height: f32,
    /// Pinna width.
    pub width: f32,
    /// Pinna depth.
    pub depth: f32,
    /// Concha depth.
    pub concha_depth: f32,
    /// Concha volume.
    pub concha_volume: f32,
}
/// Hearing status of the measured subject.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HearingAssessment {
    /// Audiogram entries; presumably frequency (Hz) → threshold (dB HL). TODO confirm.
    pub audiogram: HashMap<u32, f32>,
    /// Free-form status (the default database uses `"Normal"`).
    pub status: String,
    /// Whether the subject wears a hearing aid.
    pub hearing_aid: bool,
}
/// Integer-quantized direction and distance used as the key for HRTF lookups.
///
/// Built from a Cartesian position by `HrtfPosition::from_position3d`. Because the key
/// includes `distance_cm`, an exact lookup requires the distance to match as well;
/// angular interpolation ignores it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct HrtfPosition {
    /// Azimuth in whole degrees (`atan2(z, x)` in `from_position3d`).
    pub azimuth: i16,
    /// Elevation in whole degrees above the x/z plane.
    pub elevation: i16,
    /// Euclidean distance × 100 (centimetres if the source position is in metres).
    pub distance_cm: u16,
}
/// One binaural measurement (or an interpolated estimate) for a single position.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HrtfMeasurement {
    /// Left-ear impulse response samples.
    pub left_ir: Vec<f32>,
    /// Right-ear impulse response samples.
    pub right_ir: Vec<f32>,
    /// Quality score. `optimize_database` drops entries below 0.7, and interpolated results
    /// are assigned 1.0.
    pub quality_score: f32,
    /// Interaural time difference, in samples per the field name.
    pub itd_samples: f32,
    /// Interaural level difference, in dB per the field name.
    pub ild_db: f32,
    /// Frequency-domain view of the two responses.
    pub frequency_response: FrequencyResponse,
}
/// Magnitude and phase of both ears on a shared frequency axis.
///
/// NOTE(review): the vectors are presumably parallel (same length, indexed by
/// `frequencies`), but nothing here enforces that. Units are not established.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrequencyResponse {
    /// Frequency axis.
    pub frequencies: Vec<f32>,
    /// Left-ear magnitude per frequency bin.
    pub left_magnitude: Vec<f32>,
    /// Right-ear magnitude per frequency bin.
    pub right_magnitude: Vec<f32>,
    /// Left-ear phase per frequency bin.
    pub left_phase: Vec<f32>,
    /// Right-ear phase per frequency bin.
    pub right_phase: Vec<f32>,
}
/// A neighbouring measurement and its contribution to an interpolated HRTF.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InterpolationWeight {
    /// Key of the neighbouring measurement.
    pub position: HrtfPosition,
    /// Normalized inverse-distance weight; the weights in one neighbour set sum to 1.
    pub weight: f32,
    /// Great-circle angle (radians) between the query and this neighbour.
    pub distance: f32,
}
/// HRTF set adapted to one user's head measurements.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PersonalizedHrtf {
    /// User identifier; also the key in the manager's personalized cache.
    pub user_id: String,
    /// Measurements the adaptation was derived from.
    pub head_measurements: HeadMeasurements,
    /// Adapted measurements keyed by quantized position.
    pub measurements: HashMap<HrtfPosition, HrtfMeasurement>,
    /// Parameters used to derive `measurements`.
    pub adaptation_params: AdaptationParameters,
    /// Timestamp string. `create_personalized_hrtf` currently sets a fixed literal.
    pub last_updated: String,
}
/// Scaling factors and EQ used to personalize a generic HRTF set.
///
/// The `calculate_adaptation_parameters` stub in this file sets every scalar to `1.0`
/// and leaves `frequency_adjustments` empty.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AdaptationParameters {
    /// Head-size scaling factor.
    pub head_scaling: f32,
    /// Pinna-size scaling factor.
    pub pinna_scaling: f32,
    /// ITD adjustment factor.
    pub itd_adjustment: f32,
    /// ILD adjustment factor.
    pub ild_adjustment: f32,
    /// Per-frequency corrections.
    pub frequency_adjustments: Vec<FrequencyAdjustment>,
}
/// One spectral correction; presumably a peaking-EQ band (TODO confirm the filter type).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FrequencyAdjustment {
    /// Centre frequency (presumably Hz).
    pub frequency: f32,
    /// Gain in dB, per the field name.
    pub gain_db: f32,
    /// Filter Q (bandwidth) factor.
    pub q_factor: f32,
}
/// Configuration for [`HrtfDatabaseManager`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseConfig {
    /// Intended cache capacity (default 1000). Not read by the code visible here.
    pub cache_size: usize,
    /// Algorithm used by `get_hrtf` when there is no exact position match.
    pub interpolation_method: InterpolationMethod,
    /// Whether to interpolate across distance. Not read by the code visible here.
    pub distance_interpolation: bool,
    /// Run weight precomputation after `load_database` (currently a no-op stub).
    pub precompute_weights: bool,
    /// Run measurement compression in `optimize_database` (currently a no-op stub).
    pub compression: bool,
    /// Preferred storage format. `load_database` picks the format from the file extension
    /// instead, so this is not read by the code visible here.
    pub storage_format: StorageFormat,
}
/// Strategy for estimating an HRTF at a position that has no exact measurement.
///
/// In `HrtfDatabaseManager`, only `NearestNeighbor` and `Bilinear` are implemented. The
/// other variants currently return an "not implemented" error.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum InterpolationMethod {
    /// Closest measurement by angular distance.
    NearestNeighbor,
    /// Inverse-distance weighting of the 4 angularly nearest measurements.
    Bilinear,
    /// Not implemented.
    SphericalSpline,
    /// Not implemented.
    Barycentric,
    /// Not implemented.
    RadialBasisFunction,
}
/// On-disk database formats recognised by `HrtfDatabaseManager::load_database`.
///
/// The format is chosen from the file extension: `sofa`, `json`, `bin`, `h5`/`hdf5`.
/// Only `Json` is actually parsed; the other loaders are placeholders that return an
/// empty default database.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum StorageFormat {
    /// SOFA file (`.sofa`); placeholder loader.
    Sofa,
    /// JSON serialization of `HrtfDatabase` (`.json`).
    Json,
    /// Binary file (`.bin`); placeholder loader.
    Binary,
    /// HDF5 file (`.h5` / `.hdf5`); placeholder loader.
    Hdf5,
}
/// Running counters maintained by [`HrtfDatabaseManager`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct DatabaseMetrics {
    /// Successful `get_hrtf` calls.
    pub total_lookups: u64,
    /// Lookups served by an exact position match.
    pub cache_hits: u64,
    /// Lookups that required interpolation.
    pub interpolations: u64,
    /// Exponential moving average of lookup time, in microseconds.
    pub avg_lookup_time_us: f64,
    /// Memory usage in bytes. Not updated by the code visible here.
    pub memory_usage_bytes: u64,
}
impl HrtfDatabaseManager {
pub fn new(config: DatabaseConfig) -> Result<Self> {
let main_database = Arc::new(RwLock::new(HrtfDatabase::new()));
let personalized_cache = Arc::new(RwLock::new(HashMap::new()));
let metrics = DatabaseMetrics::default();
Ok(Self {
main_database,
personalized_cache,
config,
metrics,
})
}
pub async fn load_database(&mut self, path: &Path) -> Result<()> {
let format = self.detect_format(path)?;
let database = match format {
StorageFormat::Sofa => self.load_sofa_database(path).await?,
StorageFormat::Json => self.load_json_database(path).await?,
StorageFormat::Binary => self.load_binary_database(path).await?,
StorageFormat::Hdf5 => self.load_hdf5_database(path).await?,
};
let mut db = self.main_database.write().map_err(|e| {
Error::LegacyProcessing(format!(
"Failed to acquire write lock on HRTF database: {}",
e
))
})?;
*db = database;
if self.config.precompute_weights {
self.precompute_interpolation_weights(&mut db)?;
}
Ok(())
}
pub fn get_hrtf(&mut self, position: &Position3D) -> Result<HrtfMeasurement> {
let start_time = std::time::Instant::now();
let hrtf_pos = HrtfPosition::from_position3d(position);
let interpolated = {
let db = self.main_database.read().map_err(|e| {
Error::LegacyProcessing(format!(
"Failed to acquire read lock on HRTF database: {}",
e
))
})?;
if let Some(measurement) = db.measurements.get(&hrtf_pos) {
self.metrics.cache_hits += 1;
self.metrics.total_lookups += 1;
return Ok(measurement.clone());
}
self.interpolate_hrtf(&hrtf_pos, &db)?
};
self.metrics.interpolations += 1;
self.metrics.total_lookups += 1;
let elapsed = start_time.elapsed();
self.update_timing_metrics(elapsed.as_micros() as f64);
Ok(interpolated)
}
pub fn get_personalized_hrtf(
&self,
user_id: &str,
position: &Position3D,
) -> Result<HrtfMeasurement> {
let cache = self.personalized_cache.read().map_err(|e| {
Error::LegacyProcessing(format!(
"Failed to acquire read lock on personalized cache: {}",
e
))
})?;
let hrtf_pos = HrtfPosition::from_position3d(position);
if let Some(personalized) = cache.get(user_id) {
if let Some(measurement) = personalized.measurements.get(&hrtf_pos) {
return Ok(measurement.clone());
}
return self.interpolate_personalized_hrtf(personalized, &hrtf_pos);
}
Err(Error::hrtf("Personalized HRTF not found for user"))
}
pub fn create_personalized_hrtf(
&mut self,
user_id: String,
head_measurements: HeadMeasurements,
) -> Result<()> {
let adaptation_params = self.calculate_adaptation_parameters(&head_measurements)?;
let personalized_measurements = self.adapt_hrtf_measurements(&adaptation_params)?;
let personalized = PersonalizedHrtf {
user_id: user_id.clone(),
head_measurements,
measurements: personalized_measurements,
adaptation_params,
last_updated: "2025-07-23T00:00:00Z".to_string(),
};
let mut cache = self.personalized_cache.write().map_err(|e| {
Error::LegacyProcessing(format!(
"Failed to acquire write lock on personalized cache: {}",
e
))
})?;
cache.insert(user_id, personalized);
Ok(())
}
pub fn optimize_database(&mut self) -> Result<()> {
let mut db = self.main_database.write().map_err(|e| {
Error::LegacyProcessing(format!(
"Failed to acquire write lock on HRTF database: {}",
e
))
})?;
db.measurements
.retain(|_, measurement| measurement.quality_score >= 0.7);
self.precompute_interpolation_weights(&mut db)?;
if self.config.compression {
self.compress_measurements(&mut db)?;
}
Ok(())
}
pub fn get_statistics(&self) -> DatabaseStatistics {
let db = self
.main_database
.read()
.expect("Failed to acquire read lock on HRTF database for statistics");
let cache = self
.personalized_cache
.read()
.expect("Failed to acquire read lock on personalized cache for statistics");
DatabaseStatistics {
total_measurements: db.measurements.len(),
personalized_users: cache.len(),
cache_hit_rate: if self.metrics.total_lookups > 0 {
(self.metrics.cache_hits as f64 / self.metrics.total_lookups as f64) * 100.0
} else {
0.0
},
avg_lookup_time_us: self.metrics.avg_lookup_time_us,
memory_usage_mb: self.metrics.memory_usage_bytes as f64 / (1024.0 * 1024.0),
interpolation_rate: if self.metrics.total_lookups > 0 {
(self.metrics.interpolations as f64 / self.metrics.total_lookups as f64) * 100.0
} else {
0.0
},
}
}
fn detect_format(&self, path: &Path) -> Result<StorageFormat> {
match path.extension().and_then(|ext| ext.to_str()) {
Some("sofa") => Ok(StorageFormat::Sofa),
Some("json") => Ok(StorageFormat::Json),
Some("bin") => Ok(StorageFormat::Binary),
Some("h5") | Some("hdf5") => Ok(StorageFormat::Hdf5),
_ => Err(Error::hrtf("Unknown database format")),
}
}
async fn load_sofa_database(&self, _path: &Path) -> Result<HrtfDatabase> {
Ok(HrtfDatabase::new())
}
async fn load_json_database(&self, path: &Path) -> Result<HrtfDatabase> {
let content = tokio::fs::read_to_string(path).await?;
let database: HrtfDatabase = serde_json::from_str(&content)?;
Ok(database)
}
async fn load_binary_database(&self, _path: &Path) -> Result<HrtfDatabase> {
Ok(HrtfDatabase::new())
}
async fn load_hdf5_database(&self, _path: &Path) -> Result<HrtfDatabase> {
Ok(HrtfDatabase::new())
}
fn interpolate_hrtf(
&self,
position: &HrtfPosition,
db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
match self.config.interpolation_method {
InterpolationMethod::NearestNeighbor => {
self.nearest_neighbor_interpolation(position, db)
}
InterpolationMethod::Bilinear => self.bilinear_interpolation(position, db),
InterpolationMethod::SphericalSpline => {
self.spherical_spline_interpolation(position, db)
}
InterpolationMethod::Barycentric => self.barycentric_interpolation(position, db),
InterpolationMethod::RadialBasisFunction => self.rbf_interpolation(position, db),
}
}
fn nearest_neighbor_interpolation(
&self,
position: &HrtfPosition,
db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
let mut closest_distance = f32::INFINITY;
let mut closest_measurement = None;
for (pos, measurement) in &db.measurements {
let distance = self.calculate_angular_distance(position, pos);
if distance < closest_distance {
closest_distance = distance;
closest_measurement = Some(measurement);
}
}
closest_measurement
.cloned()
.ok_or_else(|| Error::hrtf("No HRTF measurements available"))
}
fn bilinear_interpolation(
&self,
position: &HrtfPosition,
db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
let neighbors = self.find_interpolation_neighbors(position, db, 4)?;
if neighbors.is_empty() {
return Err(Error::hrtf("No neighbors found for interpolation"));
}
self.weighted_interpolation(&neighbors, db)
}
fn spherical_spline_interpolation(
&self,
_position: &HrtfPosition,
_db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
Err(Error::hrtf(
"Spherical spline interpolation not implemented",
))
}
fn barycentric_interpolation(
&self,
_position: &HrtfPosition,
_db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
Err(Error::hrtf("Barycentric interpolation not implemented"))
}
fn rbf_interpolation(
&self,
_position: &HrtfPosition,
_db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
Err(Error::hrtf("RBF interpolation not implemented"))
}
fn find_interpolation_neighbors(
&self,
position: &HrtfPosition,
db: &HrtfDatabase,
count: usize,
) -> Result<Vec<InterpolationWeight>> {
let mut neighbors: Vec<_> = db
.measurements
.keys()
.map(|pos| {
let distance = self.calculate_angular_distance(position, pos);
InterpolationWeight {
position: *pos,
weight: 0.0,
distance,
}
})
.collect();
neighbors.sort_by(|a, b| {
a.distance
.partial_cmp(&b.distance)
.unwrap_or(std::cmp::Ordering::Equal)
});
neighbors.truncate(count);
let total_weight: f32 = neighbors.iter().map(|n| 1.0 / (n.distance + 1e-6)).sum();
for neighbor in &mut neighbors {
neighbor.weight = (1.0 / (neighbor.distance + 1e-6)) / total_weight;
}
Ok(neighbors)
}
fn weighted_interpolation(
&self,
neighbors: &[InterpolationWeight],
db: &HrtfDatabase,
) -> Result<HrtfMeasurement> {
if neighbors.is_empty() {
return Err(Error::hrtf("No neighbors for interpolation"));
}
let first_measurement = db
.measurements
.get(&neighbors[0].position)
.ok_or_else(|| Error::hrtf("Reference measurement not found"))?;
let ir_length = first_measurement.left_ir.len();
let mut left_ir = vec![0.0; ir_length];
let mut right_ir = vec![0.0; ir_length];
let mut total_itd = 0.0;
let mut total_ild = 0.0;
for neighbor in neighbors {
if let Some(measurement) = db.measurements.get(&neighbor.position) {
for i in 0..ir_length.min(measurement.left_ir.len()) {
left_ir[i] += measurement.left_ir[i] * neighbor.weight;
right_ir[i] += measurement.right_ir[i] * neighbor.weight;
}
total_itd += measurement.itd_samples * neighbor.weight;
total_ild += measurement.ild_db * neighbor.weight;
}
}
Ok(HrtfMeasurement {
left_ir,
right_ir,
quality_score: 1.0, itd_samples: total_itd,
ild_db: total_ild,
frequency_response: first_measurement.frequency_response.clone(), })
}
fn calculate_angular_distance(&self, pos1: &HrtfPosition, pos2: &HrtfPosition) -> f32 {
let az1 = pos1.azimuth as f32 * std::f32::consts::PI / 180.0;
let el1 = pos1.elevation as f32 * std::f32::consts::PI / 180.0;
let az2 = pos2.azimuth as f32 * std::f32::consts::PI / 180.0;
let el2 = pos2.elevation as f32 * std::f32::consts::PI / 180.0;
let delta_az = az2 - az1;
let delta_el = el2 - el1;
let a =
(delta_el / 2.0).sin().powi(2) + el1.cos() * el2.cos() * (delta_az / 2.0).sin().powi(2);
2.0 * a.sqrt().asin()
}
fn precompute_interpolation_weights(&self, _db: &mut HrtfDatabase) -> Result<()> {
Ok(())
}
fn interpolate_personalized_hrtf(
&self,
_personalized: &PersonalizedHrtf,
_position: &HrtfPosition,
) -> Result<HrtfMeasurement> {
Err(Error::hrtf(
"Personalized HRTF interpolation not implemented",
))
}
fn calculate_adaptation_parameters(
&self,
_head_measurements: &HeadMeasurements,
) -> Result<AdaptationParameters> {
Ok(AdaptationParameters {
head_scaling: 1.0,
pinna_scaling: 1.0,
itd_adjustment: 1.0,
ild_adjustment: 1.0,
frequency_adjustments: Vec::new(),
})
}
fn adapt_hrtf_measurements(
&self,
_params: &AdaptationParameters,
) -> Result<HashMap<HrtfPosition, HrtfMeasurement>> {
Ok(HashMap::new())
}
fn compress_measurements(&self, _db: &mut HrtfDatabase) -> Result<()> {
Ok(())
}
fn update_timing_metrics(&mut self, time_us: f64) {
let alpha = 0.1; self.metrics.avg_lookup_time_us =
alpha * time_us + (1.0 - alpha) * self.metrics.avg_lookup_time_us;
}
}
impl HrtfDatabase {
pub fn new() -> Self {
Self {
metadata: DatabaseMetadata {
name: "Default HRTF Database".to_string(),
version: "1.0.0".to_string(),
sample_rate: 48000,
hrtf_length: 512,
conditions: MeasurementConditions {
room: RoomCharacteristics {
room_type: "anechoic".to_string(),
dimensions: (5.0, 4.0, 3.0),
rt60: 0.05,
noise_floor: -40.0,
},
equipment: EquipmentInfo {
microphone: "Generic".to_string(),
speaker: "Generic".to_string(),
audio_interface: "Generic".to_string(),
software: "VoiRS".to_string(),
},
methodology: "Standard HRTF measurement".to_string(),
quality_metrics: QualityMetrics {
snr: 60.0,
thd: 0.1,
frequency_deviation: 1.0,
phase_coherence: 0.95,
},
},
subject_demographics: SubjectDemographics {
age: 25,
gender: "Mixed".to_string(),
head_measurements: HeadMeasurements {
width: 15.5,
depth: 19.0,
circumference: 56.0,
interaural_distance: 14.5,
pinna_left: PinnaMeasurements {
height: 6.2,
width: 3.5,
depth: 2.1,
concha_depth: 1.2,
concha_volume: 2.8,
},
pinna_right: PinnaMeasurements {
height: 6.2,
width: 3.5,
depth: 2.1,
concha_depth: 1.2,
concha_volume: 2.8,
},
},
hearing_assessment: HearingAssessment {
audiogram: HashMap::new(),
status: "Normal".to_string(),
hearing_aid: false,
},
},
created: "2025-07-23T00:00:00Z".to_string(),
modified: "2025-07-23T00:00:00Z".to_string(),
},
measurements: HashMap::new(),
interpolation_cache: HashMap::new(),
distance_hrtfs: HashMap::new(),
}
}
}
impl Default for HrtfDatabase {
fn default() -> Self {
Self::new()
}
}
impl HrtfPosition {
    /// Quantizes a Cartesian position into the integer key used for HRTF lookups.
    ///
    /// * `azimuth` is `atan2(z, x)` in degrees.
    /// * `elevation` is the angle of `y` above the x/z plane, in degrees.
    /// * `distance_cm` is the Euclidean norm × 100 (centimetres if `pos` is in metres;
    ///   TODO confirm the unit of `Position3D`).
    ///
    /// Each value is rounded to the nearest integer instead of being truncated toward
    /// zero. Truncation turned tiny floating-point error (e.g. 29.999998°) into an
    /// off-by-one key that missed the exact-match lookup. Out-of-range values saturate and
    /// NaN maps to 0, per Rust's float-to-int `as` semantics.
    pub fn from_position3d(pos: &Position3D) -> Self {
        let horizontal = (pos.x.powi(2) + pos.z.powi(2)).sqrt();
        let azimuth = pos.z.atan2(pos.x).to_degrees().round() as i16;
        let elevation = pos.y.atan2(horizontal).to_degrees().round() as i16;
        let distance_cm =
            ((pos.x.powi(2) + pos.y.powi(2) + pos.z.powi(2)).sqrt() * 100.0).round() as u16;
        Self {
            azimuth,
            elevation,
            distance_cm,
        }
    }
}
/// Snapshot returned by `HrtfDatabaseManager::get_statistics`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DatabaseStatistics {
    /// Number of measurements in the main database.
    pub total_measurements: usize,
    /// Number of users with a cached personalized HRTF.
    pub personalized_users: usize,
    /// Exact-match lookups as a percentage (0–100) of all lookups; 0 when there are none.
    pub cache_hit_rate: f64,
    /// Exponential moving average of lookup time, in microseconds.
    pub avg_lookup_time_us: f64,
    /// `memory_usage_bytes / 1024²`, i.e. MiB.
    pub memory_usage_mb: f64,
    /// Interpolated lookups as a percentage (0–100) of all lookups; 0 when there are none.
    pub interpolation_rate: f64,
}
impl Default for DatabaseConfig {
fn default() -> Self {
Self {
cache_size: 1000,
interpolation_method: InterpolationMethod::Bilinear,
distance_interpolation: true,
precompute_weights: true,
compression: false,
storage_format: StorageFormat::Json,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_database_creation() {
        assert!(HrtfDatabaseManager::new(DatabaseConfig::default()).is_ok());
    }

    #[test]
    fn test_hrtf_position_conversion() {
        // A point one unit along +x has zero azimuth and elevation and norm 1 (100 cm).
        let converted = HrtfPosition::from_position3d(&Position3D::new(1.0, 0.0, 0.0));
        assert_eq!(converted.azimuth, 0);
        assert_eq!(converted.elevation, 0);
        assert_eq!(converted.distance_cm, 100);
    }

    #[test]
    fn test_database_metadata() {
        let metadata = HrtfDatabase::new().metadata;
        assert_eq!(metadata.sample_rate, 48000);
        assert_eq!(metadata.hrtf_length, 512);
        assert!(!metadata.name.is_empty());
    }

    #[test]
    fn test_interpolation_methods() {
        let all_methods = [
            InterpolationMethod::NearestNeighbor,
            InterpolationMethod::Bilinear,
            InterpolationMethod::SphericalSpline,
            InterpolationMethod::Barycentric,
            InterpolationMethod::RadialBasisFunction,
        ];
        for interpolation_method in all_methods.iter().copied() {
            let config = DatabaseConfig {
                interpolation_method,
                ..DatabaseConfig::default()
            };
            assert!(HrtfDatabaseManager::new(config).is_ok());
        }
    }

    #[test]
    fn test_angular_distance_calculation() {
        let manager = HrtfDatabaseManager::new(DatabaseConfig::default())
            .expect("Failed to create HRTF database manager for angular distance test");
        let origin = HrtfPosition {
            azimuth: 0,
            elevation: 0,
            distance_cm: 100,
        };
        let rotated = HrtfPosition {
            azimuth: 90,
            ..origin
        };
        let distance = manager.calculate_angular_distance(&origin, &rotated);
        assert!(distance > 0.0);
        assert!(distance < std::f32::consts::PI);
    }

    #[test]
    fn test_database_statistics() {
        let stats = HrtfDatabaseManager::new(DatabaseConfig::default())
            .unwrap()
            .get_statistics();
        assert_eq!(stats.total_measurements, 0);
        assert_eq!(stats.personalized_users, 0);
        assert_eq!(stats.cache_hit_rate, 0.0);
    }
}