use hound::WavReader;
use math_audio_dsp::analysis::cross_correlate_envelope;
use std::path::Path;
/// Result of [`find_arrival_time`]: where the first arrival was detected in a WAV file.
///
/// All values refer to the first channel of the file only.
#[derive(Debug, Clone)]
pub struct ArrivalTimeResult {
/// Index (in first-channel samples) of the first sample whose absolute value reaches the
/// detection threshold; `0` when no sample reaches it.
pub arrival_samples: usize,
/// `arrival_samples` converted to milliseconds using `sample_rate`.
pub arrival_ms: f64,
/// Sample rate reported by the WAV header.
pub sample_rate: u32,
/// Largest absolute first-channel sample value (integer formats are divided by
/// `2^(bits_per_sample - 1)` before this is measured).
pub peak_amplitude: f32,
}
/// Detect the first arrival of sound in a WAV file by amplitude thresholding.
///
/// Only the first channel is analysed. The detection threshold is the largest of:
///
/// * `threshold_db` (default `-40.0` dB) relative to the peak absolute amplitude,
/// * 10x (20 dB above) the largest absolute value in a leading noise window of
///   `min(sample_rate / 100, len / 100)` samples, raised to at least 10 and capped at `len`,
/// * `1e-5`.
///
/// The arrival is the first sample whose absolute value reaches the threshold. If no sample
/// reaches it (e.g. the leading window already contains the signal), index `0` is reported.
///
/// # Errors
///
/// Returns `Err` when the file cannot be opened or decoded, reports zero channels or an
/// integer bit depth outside `1..=32`, contains no samples, or is effectively silent
/// (peak absolute amplitude below `1e-6`).
pub fn find_arrival_time(
    wav_path: &Path,
    threshold_db: Option<f64>,
) -> Result<ArrivalTimeResult, String> {
    let threshold_db = threshold_db.unwrap_or(-40.0);
    let (samples, sample_rate) = read_first_channel_normalized(wav_path)?;
    if samples.is_empty() {
        return Err("WAV file contains no samples".to_string());
    }

    let peak_amplitude = samples.iter().map(|&s| s.abs()).fold(0.0_f32, f32::max);
    if peak_amplitude < 1e-6 {
        return Err("Signal appears to be silent (peak amplitude < -120 dB)".to_string());
    }

    // Leading noise window: ~10 ms, no more than 1% of the file, at least 10 samples, and
    // never longer than the file itself (without the final `min`, files shorter than
    // 10 samples would panic on the slice below).
    let noise_samples = (sample_rate as usize / 100)
        .min(samples.len() / 100)
        .max(10)
        .min(samples.len());
    let noise_floor = samples[..noise_samples]
        .iter()
        .map(|&s| s.abs())
        .fold(0.0_f32, f32::max);

    let threshold_from_peak = peak_amplitude * 10.0_f32.powf(threshold_db as f32 / 20.0);
    // Require the arrival to stand 20 dB above whatever was present at the very start.
    let threshold_from_noise = noise_floor * 10.0;
    let threshold_linear = threshold_from_peak.max(threshold_from_noise).max(1e-5);

    let arrival_idx = samples
        .iter()
        .position(|s| s.abs() >= threshold_linear)
        .unwrap_or(0);
    let arrival_ms = arrival_idx as f64 * 1000.0 / sample_rate as f64;

    Ok(ArrivalTimeResult {
        arrival_samples: arrival_idx,
        arrival_ms,
        sample_rate,
        peak_amplitude,
    })
}

/// Open `wav_path` and decode its first channel to `f32`, returning the samples together
/// with the header's sample rate. Integer samples are divided by `2^(bits_per_sample - 1)`;
/// float samples are returned as stored.
fn read_first_channel_normalized(wav_path: &Path) -> Result<(Vec<f32>, u32), String> {
    let mut reader = WavReader::open(wav_path)
        .map_err(|e| format!("Failed to open WAV file {:?}: {}", wav_path, e))?;
    let spec = reader.spec();
    let channels = spec.channels as usize;
    // The first channel is every `channels`-th interleaved sample; `step_by(0)` would panic.
    if channels == 0 {
        return Err("WAV file reports zero channels".to_string());
    }
    let decoded = match spec.sample_format {
        hound::SampleFormat::Int => {
            let bits = spec.bits_per_sample;
            // `1 << (bits - 1)` underflows for 0 bits and overflows a u32 beyond 32 bits.
            if bits == 0 || bits > 32 {
                return Err(format!("Unsupported integer bit depth: {}", bits));
            }
            let max_val = (1u32 << (bits - 1)) as f32;
            reader
                .samples::<i32>()
                .step_by(channels)
                .map(|s| s.map(|v| v as f32 / max_val))
                .collect::<Result<Vec<_>, _>>()
        }
        hound::SampleFormat::Float => reader
            .samples::<f32>()
            .step_by(channels)
            .collect::<Result<Vec<_>, _>>(),
    };
    let samples = decoded.map_err(|e| format!("Failed to read samples: {}", e))?;
    Ok((samples, spec.sample_rate))
}
/// Estimate the arrival time, in milliseconds, from the slope of a curve's phase response.
///
/// The phase (degrees) is unwrapped with `super::phase_utils::unwrap_phase_degrees`,
/// restricted to points with `min_freq <= f <= max_freq`, and fitted with an ordinary
/// least-squares line `phi(f) = a + slope * f` (phi in radians). A pure delay `tau` has
/// phase `-2*pi*f*tau`, so the delay is `-slope / (2*pi)`. This assumes `curve.freq` is in
/// Hz (as in this module's tests) so that the result comes out in seconds before scaling.
///
/// Returns `None` when the curve has no phase, fewer than 5 points fall in the band, the
/// in-band frequencies are (numerically) all identical, or the estimate is not strictly
/// between 0 and 500 ms.
pub fn estimate_arrival_from_phase(
    curve: &crate::Curve,
    min_freq: f64,
    max_freq: f64,
) -> Option<f64> {
    use std::f64::consts::PI;
    let phase = curve.phase.as_ref()?;
    let unwrapped = super::phase_utils::unwrap_phase_degrees(phase);
    // (frequency, unwrapped phase in radians) for every point inside the fit band.
    let points: Vec<(f64, f64)> = curve
        .freq
        .iter()
        .zip(unwrapped.iter())
        .filter(|&(&f, _)| f >= min_freq && f <= max_freq)
        .map(|(&f, &p)| (f, p.to_radians()))
        .collect();
    if points.len() < 5 {
        return None;
    }

    let n = points.len() as f64;
    let mean_f = points.iter().map(|&(f, _)| f).sum::<f64>() / n;
    let mean_phi = points.iter().map(|&(_, p)| p).sum::<f64>() / n;
    // Centred sums give the same slope as (n*Σfφ - Σf*Σφ) / (n*Σf² - (Σf)²) but avoid
    // subtracting two large, nearly equal terms when the band is narrow relative to its
    // centre frequency.
    let (s_ff, s_fphi) = points
        .iter()
        .fold((0.0_f64, 0.0_f64), |(s_ff, s_fphi), &(f, p)| {
            let df = f - mean_f;
            (s_ff + df * df, s_fphi + df * (p - mean_phi))
        });
    // `n * s_ff` is the uncentred denominator; keep the same degeneracy threshold.
    if (n * s_ff).abs() < 1e-12 {
        return None;
    }
    let slope = s_fphi / s_ff;
    let delay_ms = -slope / (2.0 * PI) * 1000.0;
    (delay_ms > 0.0 && delay_ms < 500.0).then_some(delay_ms)
}
/// Compute the extra delay each channel needs so that every arrival lines up with the latest.
///
/// Each output entry is `latest_arrival - arrival` (same unit as the input), so the channel
/// that arrives last gets `0.0` and earlier channels get positive delays. An empty input
/// yields an empty map.
pub fn calculate_alignment_delays(
    arrival_times: &std::collections::HashMap<String, f64>,
) -> std::collections::HashMap<String, f64> {
    use std::collections::HashMap;

    // NEG_INFINITY is the identity for `max`; with no entries the loop below never runs,
    // so an empty input naturally produces an empty result.
    let latest = arrival_times
        .values()
        .fold(f64::NEG_INFINITY, |acc, &t| acc.max(t));

    let mut delays = HashMap::with_capacity(arrival_times.len());
    for (channel, &arrival) in arrival_times {
        delays.insert(channel.clone(), latest - arrival);
    }
    delays
}
/// Delay, level and detection quality of a known probe located in a recording by
/// [`detect_delay_with_probe`] or [`detect_delays_multi_channel`].
#[derive(Debug, Clone)]
pub struct ProbeDelayResult {
/// Arrival time in milliseconds, as reported by `cross_correlate_envelope`.
pub arrival_ms: f64,
/// Sample index of the cross-correlation peak (`peak_sample` from `cross_correlate_envelope`).
pub arrival_samples: usize,
/// Cross-correlation peak divided by the probe's autocorrelation peak; `0.0` when the
/// autocorrelation peak is not above `1e-10`.
pub gain_linear: f64,
/// `20 * log10(gain_linear)`, or `-120.0` when `gain_linear` is not above `1e-10`.
pub gain_db: f64,
/// Cross-correlation peak relative to the median envelope value, in dB.
pub detection_snr_db: f64,
}
/// Locate a known `probe` inside `recorded` and estimate its delay, gain and detection SNR.
///
/// The probe's own envelope autocorrelation peak is used as the 0 dB gain reference; the
/// cross-correlation and derived metrics are computed by `detect_delay_with_probe_inner`.
///
/// # Errors
///
/// Propagates any error from `cross_correlate_envelope`.
pub fn detect_delay_with_probe(
    probe: &[f32],
    recorded: &[f32],
    sample_rate: u32,
) -> Result<ProbeDelayResult, String> {
    let reference = cross_correlate_envelope(probe, probe, sample_rate)?;
    let reference_peak = reference.peak_value as f64;
    detect_delay_with_probe_inner(probe, recorded, sample_rate, reference_peak)
}
/// Cross-correlate `probe` with `recorded` and turn the result into a [`ProbeDelayResult`].
///
/// * `auto_peak` is the peak of the probe's envelope autocorrelation, used as the 0 dB gain
///   reference. When it is not above `1e-10`, the gain is reported as `0.0` (`-120.0` dB).
///
/// The detection SNR is the ratio, in dB, of the cross-correlation peak to the median of the
/// correlation envelope. The median is floored at `1e-10`, and `1e-10` is also used for an
/// empty envelope.
///
/// # Errors
///
/// Propagates any error from `cross_correlate_envelope`.
fn detect_delay_with_probe_inner(
    probe: &[f32],
    recorded: &[f32],
    sample_rate: u32,
    auto_peak: f64,
) -> Result<ProbeDelayResult, String> {
    let result = cross_correlate_envelope(probe, recorded, sample_rate)?;
    let peak = result.peak_value as f64;

    let gain_linear = if auto_peak > 1e-10 { peak / auto_peak } else { 0.0 };
    let gain_db = if gain_linear > 1e-10 {
        20.0 * gain_linear.log10()
    } else {
        -120.0
    };

    // Median envelope level as the noise reference. `select_nth_unstable_by` finds it in
    // O(n) without a full sort. `total_cmp` is a genuine total order, so NaN values in the
    // envelope cannot break the selection. A `partial_cmp(..).unwrap_or(Equal)` comparator
    // is not transitive with NaNs, and the standard library may panic on such comparators.
    let mut envelope = result.envelope.to_vec();
    let median = if envelope.is_empty() {
        1e-10
    } else {
        let mid = envelope.len() / 2;
        let (_, median_value, _) = envelope.select_nth_unstable_by(mid, |a, b| a.total_cmp(b));
        median_value.max(1e-10) as f64
    };
    let detection_snr_db = 20.0 * (peak / median).log10();

    Ok(ProbeDelayResult {
        arrival_ms: result.arrival_ms,
        arrival_samples: result.peak_sample,
        gain_linear,
        gain_db,
        detection_snr_db,
    })
}
/// Detect the probe's delay in several segments of one recording, one segment per channel.
///
/// Channel `i` is analysed on `recorded[channel_offsets[i]..min(channel_offsets[i] +
/// segment_length, recorded.len())]`, so each reported arrival is relative to that
/// segment's start. The probe autocorrelation peak is computed once and shared by every
/// channel as the gain reference. One `log::debug!` line is emitted per analysed channel.
///
/// # Errors
///
/// Returns `Err` if the autocorrelation fails, or at the first channel whose offset is not
/// inside `recorded` or whose cross-correlation fails (later channels are not processed).
pub fn detect_delays_multi_channel(
    probe: &[f32],
    recorded: &[f32],
    channel_offsets: &[usize],
    segment_length: usize,
    sample_rate: u32,
) -> Result<Vec<ProbeDelayResult>, String> {
    let shared_auto_peak = cross_correlate_envelope(probe, probe, sample_rate)?.peak_value as f64;
    let total_len = recorded.len();

    channel_offsets
        .iter()
        .enumerate()
        .map(|(ch, &start)| -> Result<ProbeDelayResult, String> {
            if start >= total_len {
                return Err(format!(
                    "Channel {} offset {} exceeds recording length {}",
                    ch, start, total_len
                ));
            }
            let stop = start.saturating_add(segment_length).min(total_len);
            let measured = detect_delay_with_probe_inner(
                probe,
                &recorded[start..stop],
                sample_rate,
                shared_auto_peak,
            )?;
            log::debug!(
                "[detect_delays_multi_channel] Ch {}: arrival={:.3}ms, gain={:.1}dB, SNR={:.1}dB",
                ch,
                measured.arrival_ms,
                measured.gain_db,
                measured.detection_snr_db
            );
            Ok(measured)
        })
        .collect()
}
// Unit tests for alignment-delay calculation, phase-slope arrival estimation and
// probe-based delay detection.
#[cfg(test)]
mod tests {
use super::*;
// The latest arrival ("R", 12 ms) gets zero delay; earlier channels are delayed by their
// lead over it.
#[test]
fn test_calculate_alignment_delays() {
let mut arrivals = std::collections::HashMap::new();
arrivals.insert("L".to_string(), 10.0);
arrivals.insert("R".to_string(), 12.0);
arrivals.insert("C".to_string(), 8.0);
let delays = calculate_alignment_delays(&arrivals);
assert!((delays["R"] - 0.0).abs() < 0.001);
assert!((delays["L"] - 2.0).abs() < 0.001);
assert!((delays["C"] - 4.0).abs() < 0.001);
}
// A synthetic pure-delay phase (-360 * f * tau degrees) must be recovered as tau (5 ms).
#[test]
fn test_estimate_arrival_from_phase() {
use ndarray::Array1;
let tau_ms = 5.0_f64;
let tau_s = tau_ms / 1000.0;
let freqs: Vec<f64> = (20..=2000).step_by(10).map(|f| f as f64).collect();
let phase_deg: Vec<f64> = freqs.iter().map(|&f| -360.0 * f * tau_s).collect();
let curve = crate::Curve {
freq: Array1::from_vec(freqs),
spl: Array1::zeros(phase_deg.len()),
phase: Some(Array1::from_vec(phase_deg)),
};
let estimated = estimate_arrival_from_phase(&curve, 200.0, 2000.0);
assert!(
estimated.is_some(),
"Should recover arrival time from phase"
);
let estimated = estimated.unwrap();
assert!(
(estimated - tau_ms).abs() < 0.1,
"Expected ~{} ms, got {} ms",
tau_ms,
estimated
);
}
// Without phase data there is nothing to fit, so the estimate is None.
#[test]
fn test_estimate_arrival_from_phase_no_phase() {
use ndarray::Array1;
let curve = crate::Curve {
freq: Array1::linspace(20.0, 2000.0, 100),
spl: Array1::zeros(100),
phase: None,
};
assert!(estimate_arrival_from_phase(&curve, 200.0, 2000.0).is_none());
}
// No arrivals in, no delays out.
#[test]
fn test_calculate_alignment_delays_empty() {
let arrivals = std::collections::HashMap::new();
let delays = calculate_alignment_delays(&arrivals);
assert!(delays.is_empty());
}
// The last speaker ("C", 5 ms) is the reference; the others are delayed up to it.
#[test]
fn test_alignment_delays_three_speakers() {
let mut arrivals = std::collections::HashMap::new();
arrivals.insert("A".to_string(), 0.0);
arrivals.insert("B".to_string(), 2.0);
arrivals.insert("C".to_string(), 5.0);
let delays = calculate_alignment_delays(&arrivals);
assert!(
(delays["A"] - 5.0).abs() < 0.001,
"A should get 5ms delay, got {}",
delays["A"]
);
assert!(
(delays["B"] - 3.0).abs() < 0.001,
"B should get 3ms delay, got {}",
delays["B"]
);
assert!(
(delays["C"] - 0.0).abs() < 0.001,
"C should get 0ms delay, got {}",
delays["C"]
);
}
// Same pure-delay check with a 3 ms delay, a coarser 100-5000 Hz grid and a 200-4000 Hz band.
#[test]
fn test_estimate_arrival_linear_phase() {
use ndarray::Array1;
let tau_ms = 3.0;
let tau_s = tau_ms / 1000.0;
let freqs: Vec<f64> = (100..=5000).step_by(20).map(|f| f as f64).collect();
let phase_deg: Vec<f64> = freqs.iter().map(|&f| -360.0 * f * tau_s).collect();
let curve = crate::Curve {
freq: Array1::from_vec(freqs),
spl: Array1::zeros(phase_deg.len()),
phase: Some(Array1::from_vec(phase_deg)),
};
let estimated = estimate_arrival_from_phase(&curve, 200.0, 4000.0);
assert!(
estimated.is_some(),
"Should recover arrival time from linear phase"
);
let estimated = estimated.unwrap();
assert!(
(estimated - tau_ms).abs() < 0.1,
"Expected ~{} ms, got {} ms (error {:.3} ms)",
tau_ms,
estimated,
(estimated - tau_ms).abs()
);
}
// Probe delayed by 480 samples (10 ms at 48 kHz) and scaled by 0.4 in an otherwise silent
// buffer: delay, SNR and gain (within 3 dB of 20*log10(0.4)) are checked.
#[test]
fn test_detect_delay_with_probe() {
let sr = 48000_u32;
let n = 4096;
let probe = math_audio_dsp::signals::gen_narrowband_probe(n, sr, 0.5, 42, 800.0, 2000.0);
let delay = 480_usize;
let atten = 0.4_f32;
let mut recorded = vec![0.0_f32; n + delay + 500];
for (i, &s) in probe.iter().enumerate() {
recorded[i + delay] += s * atten;
}
let result = detect_delay_with_probe(&probe, &recorded, sr).unwrap();
assert!(
(result.arrival_ms - 10.0).abs() < 0.2,
"Expected ~10ms arrival, got {:.3}ms",
result.arrival_ms
);
assert!(
result.detection_snr_db > 10.0,
"SNR should be high for clean signal, got {:.1}dB",
result.detection_snr_db
);
// Expected gain is the applied attenuation expressed in dB.
let expected_gain_db = 20.0 * (atten as f64).log10(); assert!(
(result.gain_db - expected_gain_db).abs() < 3.0,
"Expected gain ~{:.1}dB, got {:.1}dB",
expected_gain_db,
result.gain_db
);
}
// Three channel segments separated by silence, each containing the probe with its own
// delay (240/480/120 samples = 5/10/2.5 ms at 48 kHz) and attenuation. Arrivals are
// measured relative to each segment's offset.
#[test]
fn test_detect_delays_multi_channel() {
let sr = 48000_u32;
let n = 2048;
let probe = math_audio_dsp::signals::gen_narrowband_probe(n, sr, 0.5, 42, 800.0, 2000.0);
// Each segment is long enough to hold the probe after its largest delay.
let segment_len = n + 1000; let silence_len = 1000;
let delays = [240_usize, 480, 120]; let attens = [0.5_f32, 0.3, 0.7];
let total_len = delays.len() * (segment_len + silence_len) + silence_len;
let mut recorded = vec![0.0_f32; total_len];
let mut offsets = Vec::new();
for (ch, (&d, &a)) in delays.iter().zip(attens.iter()).enumerate() {
let offset = silence_len + ch * (segment_len + silence_len);
offsets.push(offset);
for (i, &s) in probe.iter().enumerate() {
let idx = offset + d + i;
if idx < recorded.len() {
recorded[idx] += s * a;
}
}
}
let results =
detect_delays_multi_channel(&probe, &recorded, &offsets, segment_len, sr).unwrap();
assert_eq!(results.len(), 3);
let expected_ms = [5.0, 10.0, 2.5];
for (i, (result, &expected)) in results.iter().zip(expected_ms.iter()).enumerate() {
assert!(
(result.arrival_ms - expected).abs() < 0.5,
"Channel {}: expected ~{:.1}ms, got {:.3}ms",
i,
expected,
result.arrival_ms
);
}
}
}