use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
use ndarray::Array1;
use crate::Curve;
#[derive(Debug, Clone, PartialEq)]
pub struct MicPhaseCalibration {
pub freq: Array1<f64>,
pub mag_db: Array1<f64>,
pub phase_deg: Array1<f64>,
pub coherence: Array1<f64>,
}
impl MicPhaseCalibration {
pub fn identity(freq: Array1<f64>) -> Self {
let n = freq.len();
Self {
freq,
mag_db: Array1::zeros(n),
phase_deg: Array1::zeros(n),
coherence: Array1::ones(n),
}
}
pub fn sample_at(&self, freq_hz: f64) -> Option<(f64, f64, f64)> {
let n = self.freq.len();
if n == 0 {
return None;
}
if !freq_hz.is_finite() {
return None;
}
if freq_hz <= self.freq[0] {
return Some((self.mag_db[0], self.phase_deg[0], self.coherence[0]));
}
if freq_hz >= self.freq[n - 1] {
return Some((
self.mag_db[n - 1],
self.phase_deg[n - 1],
self.coherence[n - 1],
));
}
let idx = self
.freq
.as_slice()
.expect("contiguous")
.partition_point(|&f| f < freq_hz);
let x0 = self.freq[idx - 1];
let x1 = self.freq[idx];
let dx = x1 - x0;
if dx.abs() < f64::EPSILON {
return Some((self.mag_db[idx], self.phase_deg[idx], self.coherence[idx]));
}
let t = (freq_hz - x0) / dx;
Some((
self.mag_db[idx - 1] * (1.0 - t) + self.mag_db[idx] * t,
self.phase_deg[idx - 1] * (1.0 - t) + self.phase_deg[idx] * t,
self.coherence[idx - 1] * (1.0 - t) + self.coherence[idx] * t,
))
}
pub fn apply_to_curve(&self, curve: &mut Curve) {
let n = curve.freq.len();
if n == 0 || curve.spl.len() != n {
return;
}
for i in 0..n {
let (mag, phase, coh) = match self.sample_at(curve.freq[i]) {
Some(tuple) => tuple,
None => continue,
};
curve.spl[i] -= mag;
if let Some(ref mut phase_arr) = curve.phase
&& phase_arr.len() == n
{
phase_arr[i] -= phase;
}
if let Some(ref mut coh_arr) = curve.coherence
&& coh_arr.len() == n
{
coh_arr[i] *= coh;
}
}
}
}
/// Column positions of the four required fields, resolved from a calibration file's header.
#[derive(Debug, Clone, Copy)]
struct CalColumns {
    freq: usize,
    mag: usize,
    phase: usize,
    coh: usize,
}

impl CalColumns {
    /// Highest column index used by any field; a data row needs more fields than this.
    fn max_index(self) -> usize {
        self.freq.max(self.mag).max(self.phase).max(self.coh)
    }
}

/// Splits one non-empty, trimmed line into fields.
///
/// A line containing a comma is split on commas; otherwise it is split on whitespace.
/// Surrounding double quotes are removed from each field so quoted CSV exports parse.
fn split_cal_fields(line: &str) -> Vec<&str> {
    if line.contains(',') {
        line.split(',')
            .map(|s| s.trim().trim_matches('"').trim())
            .collect()
    } else {
        line.split_whitespace().map(|s| s.trim_matches('"')).collect()
    }
}

/// Identifies the frequency, magnitude, phase and coherence columns from header names.
///
/// Matching is case-insensitive. Each header cell is assigned to at most one role. Roles are
/// tried in this order:
/// - coherence: the name is exactly `coherence`;
/// - phase: the name contains `phase`;
/// - frequency: the name contains `freq` or is `hz`;
/// - magnitude: the name contains `mag_db`, `magnitude` or `spl`, or is `db`.
///
/// The first cell matching a role wins. Returns `None` if any role is left unmatched.
fn resolve_cal_columns(header: &[&str]) -> Option<CalColumns> {
    let (mut freq, mut mag, mut phase, mut coh) = (None, None, None, None);
    for (idx, name) in header.iter().enumerate() {
        let lower = name.to_lowercase();
        if coh.is_none() && lower == "coherence" {
            coh = Some(idx);
        } else if phase.is_none() && lower.contains("phase") {
            phase = Some(idx);
        } else if freq.is_none() && (lower.contains("freq") || lower == "hz") {
            freq = Some(idx);
        } else if mag.is_none()
            && (lower.contains("mag_db")
                || lower.contains("magnitude")
                || lower.contains("spl")
                || lower == "db")
        {
            mag = Some(idx);
        }
    }
    Some(CalColumns {
        freq: freq?,
        mag: mag?,
        phase: phase?,
        coh: coh?,
    })
}

/// Loads a microphone calibration table from a delimited text file.
///
/// Blank lines and lines starting with `#` or `//` are ignored. The first remaining line is the
/// header. It must name frequency, magnitude, phase and coherence columns, in any order; see
/// `resolve_cal_columns` for the accepted names.
///
/// Fields are comma-separated when a line contains a comma, otherwise whitespace-separated.
/// A leading UTF-8 byte-order mark is ignored. Data rows are skipped (and logged at debug
/// level) when they have too few fields, contain unparseable numbers, or contain non-finite
/// values.
///
/// # Errors
///
/// Returns an error if:
/// - the file cannot be opened or read;
/// - the header lacks one of the four required columns;
/// - no valid data row remains;
/// - the frequencies are not strictly increasing.
pub fn load_mic_phase_calibration(path: &Path) -> Result<MicPhaseCalibration, Box<dyn Error>> {
    let reader = BufReader::new(File::open(path)?);
    let mut columns: Option<CalColumns> = None;
    let mut freqs = Vec::new();
    let mut mags = Vec::new();
    let mut phases = Vec::new();
    let mut cohs = Vec::new();
    for (line_num, line) in reader.lines().enumerate() {
        let line = line?;
        // Spreadsheet exports often start with a UTF-8 BOM, which `trim` does not remove.
        let trimmed = line.trim_start_matches('\u{feff}').trim();
        if trimmed.is_empty() || trimmed.starts_with('#') || trimmed.starts_with("//") {
            continue;
        }
        let parts = split_cal_fields(trimmed);
        let Some(cols) = columns else {
            // First meaningful line: it must be the header.
            let resolved = resolve_cal_columns(&parts).ok_or_else(|| {
                format!(
                    "mic phase calibration at {path:?} must have named columns for \
                     frequency / magnitude_db / phase_deg / coherence; got header {parts:?}"
                )
            })?;
            columns = Some(resolved);
            continue;
        };
        if parts.len() <= cols.max_index() {
            log::debug!(
                "[mic_phase_calibration] Skipping short row {} in {path:?}: {:?}",
                line_num + 1,
                parts
            );
            continue;
        }
        let (Ok(f), Ok(m), Ok(p), Ok(c)) = (
            parts[cols.freq].parse::<f64>(),
            parts[cols.mag].parse::<f64>(),
            parts[cols.phase].parse::<f64>(),
            parts[cols.coh].parse::<f64>(),
        ) else {
            log::debug!(
                "[mic_phase_calibration] Skipping malformed row {} in {path:?}: {:?}",
                line_num + 1,
                parts
            );
            continue;
        };
        if ![f, m, p, c].into_iter().all(f64::is_finite) {
            log::debug!(
                "[mic_phase_calibration] Skipping non-finite row {} in {path:?}: {:?}",
                line_num + 1,
                parts
            );
            continue;
        }
        freqs.push(f);
        mags.push(m);
        phases.push(p);
        cohs.push(c);
    }
    if freqs.is_empty() {
        return Err(
            format!("mic phase calibration at {path:?} contained no valid data rows").into(),
        );
    }
    // `sample_at` binary-searches the grid, so it must be strictly increasing.
    if let Some(pair) = freqs.windows(2).find(|pair| pair[0] >= pair[1]) {
        return Err(format!(
            "mic phase calibration at {path:?} frequencies are not strictly increasing \
             (found {} before {})",
            pair[0], pair[1]
        )
        .into());
    }
    Ok(MicPhaseCalibration {
        freq: Array1::from_vec(freqs),
        mag_db: Array1::from_vec(mags),
        phase_deg: Array1::from_vec(phases),
        coherence: Array1::from_vec(cohs),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use tempfile::NamedTempFile;

    /// Writes `csv` to a temporary file that lives as long as the returned handle.
    fn write_cal_csv(csv: &str) -> NamedTempFile {
        let mut f = NamedTempFile::new().unwrap();
        f.write_all(csv.as_bytes()).unwrap();
        f.flush().unwrap();
        f
    }

    #[test]
    fn loads_canonical_four_column_csv() {
        let csv = "\
frequency_hz,mag_db,phase_deg,coherence
20.0,2.0,-10.0,0.95
50.0,1.0,-5.0,0.98
200.0,0.0,0.0,0.99
2000.0,-0.5,2.0,0.99
20000.0,-3.0,15.0,0.90
";
        let f = write_cal_csv(csv);
        let cal = load_mic_phase_calibration(f.path()).unwrap();
        assert_eq!(cal.freq.len(), 5);
        assert!((cal.freq[0] - 20.0).abs() < 1e-9);
        assert!((cal.mag_db[0] - 2.0).abs() < 1e-9);
        assert!((cal.phase_deg[1] + 5.0).abs() < 1e-9);
        assert!((cal.coherence[4] - 0.90).abs() < 1e-9);
    }

    #[test]
    fn column_order_is_header_driven() {
        let csv = "\
phase_deg,coherence,frequency_hz,mag_db
-10.0,0.95,20.0,2.0
-5.0,0.98,50.0,1.0
";
        let f = write_cal_csv(csv);
        let cal = load_mic_phase_calibration(f.path()).unwrap();
        assert!((cal.freq[0] - 20.0).abs() < 1e-9);
        assert!((cal.phase_deg[0] + 10.0).abs() < 1e-9);
        assert!((cal.mag_db[0] - 2.0).abs() < 1e-9);
    }

    #[test]
    fn whitespace_delimited_file_with_comments_loads() {
        let csv = "\
# exported calibration
// second comment style

Freq Mag_dB Phase_deg Coherence
20.0 2.0 -10.0 0.95
50.0 1.0 -5.0 0.98
";
        let f = write_cal_csv(csv);
        let cal = load_mic_phase_calibration(f.path()).unwrap();
        assert_eq!(cal.freq.len(), 2);
        assert!((cal.freq[1] - 50.0).abs() < 1e-9);
        assert!((cal.mag_db[1] - 1.0).abs() < 1e-9);
        assert!((cal.coherence[0] - 0.95).abs() < 1e-9);
    }

    #[test]
    fn missing_column_rejects_file() {
        let csv = "\
frequency_hz,mag_db,phase_deg
20.0,2.0,-10.0
";
        let f = write_cal_csv(csv);
        assert!(load_mic_phase_calibration(f.path()).is_err());
    }

    #[test]
    fn header_only_file_is_rejected() {
        let f = write_cal_csv("frequency_hz,mag_db,phase_deg,coherence\n");
        let err = load_mic_phase_calibration(f.path())
            .unwrap_err()
            .to_string();
        assert!(err.contains("no valid data rows"), "got: {err}");
    }

    #[test]
    fn non_monotonic_frequencies_reject_file() {
        let csv = "\
frequency_hz,mag_db,phase_deg,coherence
200.0,0.0,0.0,0.99
20.0,2.0,-10.0,0.95
";
        let f = write_cal_csv(csv);
        let err = load_mic_phase_calibration(f.path())
            .unwrap_err()
            .to_string();
        assert!(err.contains("not strictly increasing"), "got: {err}");
    }

    #[test]
    fn duplicate_frequencies_reject_file() {
        let csv = "\
frequency_hz,mag_db,phase_deg,coherence
100.0,0.0,0.0,0.99
100.0,1.0,0.0,0.99
";
        let f = write_cal_csv(csv);
        let err = load_mic_phase_calibration(f.path())
            .unwrap_err()
            .to_string();
        assert!(err.contains("not strictly increasing"), "got: {err}");
    }

    #[test]
    fn malformed_row_is_dropped_but_loader_still_succeeds() {
        let csv = "\
frequency_hz,mag_db,phase_deg,coherence
20.0,2.0,-10.0,0.95
50.0,not-a-number,-5.0,0.98
200.0,0.0,0.0,0.99
";
        let f = write_cal_csv(csv);
        let cal = load_mic_phase_calibration(f.path()).unwrap();
        assert_eq!(cal.freq.len(), 2);
        assert!((cal.freq[1] - 200.0).abs() < 1e-9);
    }

    #[test]
    fn sample_at_exact_node_matches_stored_value() {
        let cal = MicPhaseCalibration {
            freq: Array1::from_vec(vec![20.0, 200.0, 2000.0]),
            mag_db: Array1::from_vec(vec![2.0, 0.0, -3.0]),
            phase_deg: Array1::from_vec(vec![-10.0, 0.0, 5.0]),
            coherence: Array1::from_vec(vec![0.95, 0.99, 0.90]),
        };
        let (m, p, c) = cal.sample_at(200.0).unwrap();
        assert!((m - 0.0).abs() < 1e-9);
        assert!((p - 0.0).abs() < 1e-9);
        assert!((c - 0.99).abs() < 1e-9);
    }

    #[test]
    fn sample_at_midpoint_interpolates() {
        let cal = MicPhaseCalibration {
            freq: Array1::from_vec(vec![100.0, 200.0]),
            mag_db: Array1::from_vec(vec![0.0, 4.0]),
            phase_deg: Array1::from_vec(vec![0.0, 20.0]),
            coherence: Array1::from_vec(vec![1.0, 0.5]),
        };
        let (m, p, c) = cal.sample_at(150.0).unwrap();
        assert!((m - 2.0).abs() < 1e-9, "mag @ 150 Hz should be 2 dB, got {m}");
        assert!(
            (p - 10.0).abs() < 1e-9,
            "phase @ 150 Hz should be 10°, got {p}"
        );
        assert!(
            (c - 0.75).abs() < 1e-9,
            "coherence @ 150 Hz should be 0.75, got {c}"
        );
    }

    #[test]
    fn sample_at_below_min_returns_first() {
        let cal = MicPhaseCalibration {
            freq: Array1::from_vec(vec![50.0, 500.0]),
            mag_db: Array1::from_vec(vec![1.0, 2.0]),
            phase_deg: Array1::from_vec(vec![-10.0, 5.0]),
            coherence: Array1::from_vec(vec![0.9, 0.95]),
        };
        let (m, _, _) = cal.sample_at(10.0).unwrap();
        assert!((m - 1.0).abs() < 1e-9);
    }

    #[test]
    fn sample_at_above_max_returns_last() {
        let cal = MicPhaseCalibration {
            freq: Array1::from_vec(vec![50.0, 500.0]),
            mag_db: Array1::from_vec(vec![1.0, 2.0]),
            phase_deg: Array1::from_vec(vec![-10.0, 5.0]),
            coherence: Array1::from_vec(vec![0.9, 0.95]),
        };
        let (m, _, _) = cal.sample_at(50_000.0).unwrap();
        assert!((m - 2.0).abs() < 1e-9);
    }

    #[test]
    fn sample_at_rejects_non_finite_and_empty() {
        let cal = MicPhaseCalibration::identity(Array1::from_vec(vec![20.0, 200.0]));
        assert_eq!(cal.sample_at(f64::NAN), None);
        assert_eq!(cal.sample_at(f64::INFINITY), None);
        let empty = MicPhaseCalibration::identity(Array1::from_vec(Vec::new()));
        assert_eq!(empty.sample_at(100.0), None);
    }

    #[test]
    fn identity_is_transparent() {
        let freq = Array1::from_vec(vec![20.0, 200.0, 2000.0]);
        let cal = MicPhaseCalibration::identity(freq.clone());
        let mut curve = Curve {
            freq: freq.clone(),
            spl: Array1::from_vec(vec![60.0, 80.0, 90.0]),
            phase: Some(Array1::from_vec(vec![-30.0, 0.0, 45.0])),
            coherence: Some(Array1::from_vec(vec![0.98, 0.99, 0.97])),
            ..Default::default()
        };
        let before = curve.clone();
        cal.apply_to_curve(&mut curve);
        assert_eq!(before.spl, curve.spl);
        assert_eq!(before.phase, curve.phase);
        assert_eq!(before.coherence, curve.coherence);
    }

    #[test]
    fn apply_subtracts_mag_and_phase() {
        let freq = Array1::from_vec(vec![20.0, 200.0, 2000.0]);
        let cal = MicPhaseCalibration {
            freq: freq.clone(),
            mag_db: Array1::from_vec(vec![2.0, 0.0, -3.0]),
            phase_deg: Array1::from_vec(vec![-10.0, 0.0, 5.0]),
            coherence: Array1::from_vec(vec![0.5, 1.0, 0.8]),
        };
        let mut curve = Curve {
            freq: freq.clone(),
            spl: Array1::from_vec(vec![60.0, 80.0, 90.0]),
            phase: Some(Array1::from_vec(vec![0.0, 0.0, 0.0])),
            coherence: Some(Array1::from_vec(vec![1.0, 1.0, 1.0])),
            ..Default::default()
        };
        cal.apply_to_curve(&mut curve);

        // SPL: measured minus mic magnitude.
        assert!((curve.spl[0] - 58.0).abs() < 1e-9);
        assert!((curve.spl[1] - 80.0).abs() < 1e-9);
        assert!((curve.spl[2] - 93.0).abs() < 1e-9);

        // Phase: measured minus mic phase.
        let phase = curve.phase.as_ref().unwrap();
        assert!((phase[0] - 10.0).abs() < 1e-9);
        assert!((phase[1] - 0.0).abs() < 1e-9);
        assert!((phase[2] + 5.0).abs() < 1e-9);

        // Coherence: measured times mic coherence.
        let coh = curve.coherence.as_ref().unwrap();
        assert!((coh[0] - 0.5).abs() < 1e-9);
        assert!((coh[1] - 1.0).abs() < 1e-9);
        assert!((coh[2] - 0.8).abs() < 1e-9);
    }

    #[test]
    fn apply_skips_when_phase_or_coherence_absent() {
        let freq = Array1::from_vec(vec![20.0, 200.0]);
        let cal = MicPhaseCalibration {
            freq: freq.clone(),
            mag_db: Array1::from_vec(vec![2.0, 0.0]),
            phase_deg: Array1::from_vec(vec![-10.0, 0.0]),
            coherence: Array1::from_vec(vec![0.5, 1.0]),
        };
        let mut curve = Curve {
            freq,
            spl: Array1::from_vec(vec![60.0, 80.0]),
            phase: None,
            coherence: None,
            ..Default::default()
        };
        cal.apply_to_curve(&mut curve);
        assert_eq!(curve.phase, None);
        assert_eq!(curve.coherence, None);
        assert!((curve.spl[0] - 58.0).abs() < 1e-9);
    }

    #[test]
    fn apply_leaves_mismatched_phase_untouched() {
        let freq = Array1::from_vec(vec![20.0, 200.0]);
        let cal = MicPhaseCalibration {
            freq: freq.clone(),
            mag_db: Array1::from_vec(vec![2.0, 0.0]),
            phase_deg: Array1::from_vec(vec![-10.0, 0.0]),
            coherence: Array1::from_vec(vec![0.5, 1.0]),
        };
        let mut curve = Curve {
            freq,
            spl: Array1::from_vec(vec![60.0, 80.0]),
            phase: Some(Array1::from_vec(vec![0.0])),
            coherence: None,
            ..Default::default()
        };
        cal.apply_to_curve(&mut curve);
        assert_eq!(curve.phase, Some(Array1::from_vec(vec![0.0])));
        assert!((curve.spl[0] - 58.0).abs() < 1e-9);
        assert!((curve.spl[1] - 80.0).abs() < 1e-9);
    }
}