use crate::Curve;
use log::info;
use math_audio_iir_fir::{Biquad, FirDesignConfig, FirPhase, PreRingingConfig};
use ndarray::Array1;
use super::phase_utils;
/// Output of [`decompose_phase`]: `(minimum_phase, excess_phase, delay_ms, residual)`.
///
/// `delay_ms` is the delay, in milliseconds, estimated from the excess phase.
/// `residual` is the second value returned by
/// `phase_utils::estimate_delay_from_excess_phase` (presumably the excess phase with
/// the pure-delay term removed). It is smoothed on a log-frequency axis when
/// `MixedPhaseConfig::phase_smoothing_octaves > 0`.
///
/// The residual is logged and consumed in degrees. NOTE(review): the other two
/// phase arrays are presumably also in degrees — confirm in `phase_utils`.
pub type PhaseDecomposition = (Array1<f64>, Array1<f64>, f64, Array1<f64>);
/// Tuning knobs for mixed-phase correction, in which the residual excess phase
/// of a measurement is cancelled by an FIR (see
/// `generate_excess_phase_fir_with_depth`).
#[derive(Debug, Clone)]
pub struct MixedPhaseConfig {
    /// Requested FIR length in milliseconds.
    ///
    /// It is converted to a tap count at the design sample rate, forced odd,
    /// and clamped to at least 31 taps. Half of it is also the pre-ringing time
    /// limit.
    pub max_fir_length_ms: f64,
    /// Threshold (dB) forwarded to the pre-ringing suppressor as `threshold_db`.
    pub pre_ringing_threshold_db: f64,
    /// When a per-frequency correction-depth mask is supplied, points whose depth
    /// is below this value receive no phase correction.
    pub min_spatial_depth: f64,
    /// Width, in octaves, of the log-frequency moving average applied to the
    /// residual excess phase. Values `<= 0.0` disable smoothing.
    pub phase_smoothing_octaves: f64,
}

impl Default for MixedPhaseConfig {
    fn default() -> Self {
        // Sixth-octave smoothing window.
        let phase_smoothing_octaves = 1.0 / 6.0;
        let min_spatial_depth = 0.5;
        let pre_ringing_threshold_db = -30.0;
        let max_fir_length_ms = 10.0;
        MixedPhaseConfig {
            max_fir_length_ms,
            pre_ringing_threshold_db,
            min_spatial_depth,
            phase_smoothing_octaves,
        }
    }
}
/// Bundle describing a designed mixed-phase correction.
///
/// NOTE(review): nothing in this file constructs this struct. The field notes
/// below are inferred from the field names and from the values produced by
/// `decompose_phase` / `generate_excess_phase_fir_with_depth`; confirm them
/// against the code that fills it in.
#[derive(Debug, Clone)]
pub struct MixedPhaseResult {
/// Presumably the IIR (biquad) stage of the correction.
pub iir_filters: Vec<Biquad>,
/// Presumably the FIR taps from `generate_excess_phase_fir_with_depth`.
pub fir_coefficients: Vec<f64>,
/// Presumably the delay estimated by `decompose_phase`, in milliseconds.
pub estimated_delay_ms: f64,
/// Presumably the minimum-phase component returned by `decompose_phase`.
pub minimum_phase: Array1<f64>,
/// Presumably the excess-phase component returned by `decompose_phase`.
pub excess_phase: Array1<f64>,
/// Presumably the (smoothed) residual excess phase the FIR was designed to cancel.
pub residual_excess_phase: Array1<f64>,
/// Presumably the FIR length, i.e. `fir_coefficients.len()` — TODO confirm.
pub fir_taps: usize,
}
/// Splits a measurement's phase into minimum-phase and excess-phase parts and
/// estimates the delay contained in the excess phase.
///
/// The steps are:
/// 1. Unwrap the measured phase.
/// 2. Reconstruct the minimum phase from the magnitude (`spl`).
/// 3. Take their difference (the excess phase) and pass it to
///    `phase_utils::estimate_delay_from_excess_phase`, which yields a delay in
///    milliseconds and a residual.
/// 4. Smooth the residual on a log-frequency axis when
///    `config.phase_smoothing_octaves > 0.0`.
///
/// Returns `(minimum_phase, excess_phase, delay_ms, smoothed_residual)`.
///
/// # Errors
///
/// Returns `Err` in either of these cases:
/// - `measurement.phase` is `None`.
/// - The phase array does not have exactly one value per frequency point.
pub fn decompose_phase(
    measurement: &Curve,
    config: &MixedPhaseConfig,
) -> Result<PhaseDecomposition, String> {
    let measured_phase = measurement
        .phase
        .as_ref()
        .ok_or("Mixed-phase correction requires phase data in measurements")?;
    // Reject misaligned data up front. Otherwise the phase arrays would be
    // combined element-wise with frequency-aligned arrays further down.
    if measured_phase.len() != measurement.freq.len() {
        return Err(format!(
            "Mixed-phase correction requires one phase value per frequency point \
             (got {} phase values for {} frequencies)",
            measured_phase.len(),
            measurement.freq.len(),
        ));
    }

    let unwrapped_phase = phase_utils::unwrap_phase_degrees(measured_phase);
    let min_phase = phase_utils::reconstruct_minimum_phase(&measurement.freq, &measurement.spl);
    let excess_phase = phase_utils::compute_excess_phase(&unwrapped_phase, &min_phase);
    let (delay_ms, residual) =
        phase_utils::estimate_delay_from_excess_phase(&measurement.freq, &excess_phase);

    // One pass for both extremes; f64::min/max skip NaN entries.
    let (residual_min, residual_max) = residual
        .iter()
        .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &r| {
            (lo.min(r), hi.max(r))
        });
    info!(
        " Mixed-phase decomposition: delay={:.2} ms, residual phase range={:.1}°",
        delay_ms,
        residual_max - residual_min,
    );

    // `residual` is not needed afterwards, so move it instead of cloning.
    let smoothed_residual = if config.phase_smoothing_octaves > 0.0 {
        smooth_phase_log_freq(&residual, &measurement.freq, config.phase_smoothing_octaves)
    } else {
        residual
    };
    Ok((min_phase, excess_phase, delay_ms, smoothed_residual))
}
/// Designs the excess-phase correction FIR with every frequency corrected.
///
/// This is [`generate_excess_phase_fir_with_depth`] without a correction-depth
/// mask. Consequently `config.min_spatial_depth` has no effect here.
pub fn generate_excess_phase_fir(
    freq: &Array1<f64>,
    residual_phase_deg: &Array1<f64>,
    config: &MixedPhaseConfig,
    sample_rate: f64,
) -> Vec<f64> {
    let unmasked: Option<&Array1<f64>> = None;
    generate_excess_phase_fir_with_depth(freq, residual_phase_deg, config, sample_rate, unmasked)
}
/// Designs the FIR that cancels `residual_phase_deg`, optionally masked per frequency.
///
/// FIR length:
/// - `config.max_fir_length_ms` is converted to samples at `sample_rate` and rounded.
/// - An even count is bumped to the next odd number.
/// - The result is clamped to at least 31 taps.
///
/// Target phase:
/// - The target is the negated residual.
/// - When `correction_depth` is given, points whose depth is below
///   `config.min_spatial_depth` get a zero target (no correction).
///
/// Pre-ringing suppression is enabled with `config.pre_ringing_threshold_db`
/// and a time limit of half the requested FIR length.
///
/// If `freq` is empty there is nothing to correct. The phase interpolator then
/// yields an all-zero target, i.e. the zero-phase design (approximately an
/// impulse at the centre tap).
///
/// # Panics
///
/// Panics if `correction_depth` is given and its length differs from
/// `residual_phase_deg`.
pub fn generate_excess_phase_fir_with_depth(
    freq: &Array1<f64>,
    residual_phase_deg: &Array1<f64>,
    config: &MixedPhaseConfig,
    sample_rate: f64,
    correction_depth: Option<&Array1<f64>>,
) -> Vec<f64> {
    let n_taps = (config.max_fir_length_ms / 1000.0 * sample_rate).round() as usize;
    // An odd length gives a single, well-defined centre tap.
    let n_taps = if n_taps.is_multiple_of(2) {
        n_taps + 1
    } else {
        n_taps
    };
    let n_taps = n_taps.max(31);

    let correction_phase_deg: Vec<f64> = if let Some(depth) = correction_depth {
        assert_eq!(
            residual_phase_deg.len(),
            depth.len(),
            "correction_depth length ({}) must match residual_phase_deg length ({})",
            depth.len(),
            residual_phase_deg.len(),
        );
        residual_phase_deg
            .iter()
            .zip(depth.iter())
            .map(|(&p, &d)| {
                if d >= config.min_spatial_depth {
                    -p
                } else {
                    // Insufficient spatial depth: leave this frequency uncorrected.
                    0.0
                }
            })
            .collect()
    } else {
        residual_phase_deg.iter().map(|&p| -p).collect()
    };

    // `as_slice()` returns `None` for non-standard-layout arrays, so take an owned
    // contiguous copy instead of unwrapping. This also gives safe access to the
    // frequency range.
    let freqs: Vec<f64> = freq.to_vec();
    let (min_freq, max_freq) = match (freqs.first(), freqs.last()) {
        (Some(&lo), Some(&hi)) => (lo, hi),
        // No frequency points: neither bound is used by the designer below.
        _ => (0.0, sample_rate / 2.0),
    };

    let fir_config = FirDesignConfig {
        n_taps,
        sample_rate,
        // Not read by `generate_phase_only_fir`, which always centres the impulse
        // at tap `n_taps / 2`.
        phase: FirPhase::Minimum,
        min_freq,
        max_freq,
        pre_ringing: Some(PreRingingConfig {
            threshold_db: config.pre_ringing_threshold_db,
            // Half of the requested FIR length, in seconds.
            max_time_s: config.max_fir_length_ms / 1000.0 / 2.0,
        }),
        ..Default::default()
    };

    // Phase-only design: the magnitude target is flat (0 dB).
    let magnitude_db: Vec<f64> = vec![0.0; freqs.len()];
    generate_phase_only_fir(&freqs, &magnitude_db, &correction_phase_deg, &fir_config)
}
/// Designs a real FIR of `config.n_taps` taps with approximately unit magnitude
/// and the requested phase response.
///
/// `phase_deg` (degrees) is given at `freqs` (Hz). It is interpolated on a
/// log-frequency axis onto a uniform FFT grid. The resulting impulse is
/// circularly shifted so that time zero lands on tap `n_taps / 2`, so the
/// filter carries that many samples of delay on top of the requested phase.
///
/// Only `config.n_taps`, `config.sample_rate` and `config.pre_ringing` are read.
/// `_magnitude_db` is ignored; the magnitude target is always unity (0 dB).
///
/// NOTE(review): assumes `phase_deg.len() >= freqs.len()`, otherwise the
/// interpolator indexes out of bounds — TODO confirm callers guarantee this.
fn generate_phase_only_fir(
freqs: &[f64],
_magnitude_db: &[f64],
phase_deg: &[f64],
config: &FirDesignConfig,
) -> Vec<f64> {
use num_complex::Complex64;
use rustfft::FftPlanner;
let n_taps = config.n_taps;
let sample_rate = config.sample_rate;
// Zero-padded FFT size: at least 4x the tap count and at least 4096 points,
// rounded up to a power of two.
let fft_size = (n_taps * 4).max(4096).next_power_of_two();
// One-sided spectrum: bins 0 (DC) ..= fft_size / 2 (Nyquist).
let n_bins = fft_size / 2 + 1;
let freq_step = sample_rate / fft_size as f64;
let linear_freqs: Vec<f64> = (0..n_bins).map(|i| i as f64 * freq_step).collect();
let interp_phase = interpolate_phase_log_space(freqs, phase_deg, &linear_freqs);
// Unit-magnitude bins carrying the interpolated phase.
let mut spectrum: Vec<Complex64> = interp_phase
.iter()
.map(|&phase| {
let phi = phase.to_radians();
Complex64::from_polar(1.0, phi)
})
.collect();
// DC must be real for a real impulse response: force unity gain, zero phase.
spectrum[0] = Complex64::new(1.0, 0.0);
// The Nyquist bin must also be real: keep only its real part.
if n_bins > 1 {
spectrum[n_bins - 1] = Complex64::new(spectrum[n_bins - 1].re, 0.0);
}
// Build the full Hermitian-symmetric spectrum. The negative-frequency bins
// are conjugates of bins n_bins-2 ..= 1, in reverse order, so the inverse FFT
// is real.
let mut full_spectrum: Vec<Complex64> = Vec::with_capacity(fft_size);
full_spectrum.extend_from_slice(&spectrum);
for i in (1..n_bins - 1).rev() {
full_spectrum.push(spectrum[i].conj());
}
let mut planner = FftPlanner::new();
let ifft = planner.plan_fft_inverse(fft_size);
ifft.process(&mut full_spectrum);
// rustfft does not normalise, so scale by 1 / fft_size here.
let ir: Vec<f64> = full_spectrum
.iter()
.map(|c| c.re / fft_size as f64)
.collect();
// Circularly rotate so IR sample 0 lands on tap `center`. Taps before the
// centre take the anti-causal tail from the end of the IFFT buffer.
let center = n_taps / 2;
let mut final_ir = vec![0.0; n_taps];
for (i, val) in final_ir.iter_mut().enumerate() {
let shift = i as isize - center as isize;
let ir_idx = if shift < 0 {
fft_size as isize + shift
} else {
shift
};
*val = ir[ir_idx as usize];
}
// Taper the truncated response with a Hann window.
let window =
math_audio_iir_fir::generate_window(n_taps, math_audio_iir_fir::WindowType::Hann, 0.0);
for (x, w) in final_ir.iter_mut().zip(window.iter()) {
*x *= w;
}
// Optional pre-ringing suppression, delegated to math_audio_iir_fir.
if let Some(pr_config) = &config.pre_ringing {
math_audio_iir_fir::suppress_pre_ringing(&mut final_ir, pr_config, sample_rate);
}
// Re-normalise: take the spectrum of the windowed, zero-padded FIR and scale
// every bin whose magnitude exceeds 1e-12 back to unit magnitude, keeping its
// phase. This undoes magnitude ripple from windowing and pre-ringing
// suppression.
let mut renorm_spectrum: Vec<Complex64> = vec![Complex64::new(0.0, 0.0); fft_size];
for (i, &v) in final_ir.iter().enumerate() {
renorm_spectrum[i] = Complex64::new(v, 0.0);
}
let fft = planner.plan_fft_forward(fft_size);
fft.process(&mut renorm_spectrum);
for bin in renorm_spectrum.iter_mut() {
let mag = bin.norm();
if mag > 1e-12 {
*bin /= mag;
}
}
// Back to the time domain. Keep the first n_taps samples (normalised by
// 1 / fft_size); anything beyond them is discarded.
let ifft2 = planner.plan_fft_inverse(fft_size);
ifft2.process(&mut renorm_spectrum);
let inv = 1.0 / fft_size as f64;
let mut renorm_ir = vec![0.0; n_taps];
for (i, val) in renorm_ir.iter_mut().enumerate() {
*val = renorm_spectrum[i].re * inv;
}
renorm_ir
}
/// Linearly interpolates phase values on a log2-frequency axis.
///
/// Behaviour:
/// - Targets outside the source range take the nearest end value (no
///   extrapolation).
/// - An empty source yields all zeros.
/// - A single source point yields that constant everywhere.
/// - Frequencies below 1 Hz (including DC) are clamped to 1 Hz before taking
///   the logarithm, so the axis stays finite.
///
/// `src_freqs` is expected to be ascending; `src_phase` must be at least as
/// long as `src_freqs`.
fn interpolate_phase_log_space(
    src_freqs: &[f64],
    src_phase: &[f64],
    target_freqs: &[f64],
) -> Vec<f64> {
    let to_log = |f: f64| f.max(1.0).log2();

    match src_freqs.len() {
        0 => return vec![0.0; target_freqs.len()],
        1 => return vec![src_phase[0]; target_freqs.len()],
        _ => {}
    }

    let log_axis: Vec<f64> = src_freqs.iter().copied().map(to_log).collect();
    let last = log_axis.len() - 1;

    let mut out = Vec::with_capacity(target_freqs.len());
    for &target in target_freqs {
        let x = to_log(target);
        let value = if x <= log_axis[0] {
            src_phase[0]
        } else if x >= log_axis[last] {
            src_phase[last]
        } else {
            // First node at or above x; always in 1..=last here.
            let hi = log_axis.partition_point(|&v| v < x).clamp(1, last);
            let lo = hi - 1;
            let frac = (x - log_axis[lo]) / (log_axis[hi] - log_axis[lo]);
            src_phase[lo] + frac * (src_phase[hi] - src_phase[lo])
        };
        out.push(value);
    }
    out
}
/// Smooths `phase` with a moving average `width_octaves` wide on a
/// log2-frequency axis.
///
/// For each point `i`, the output is the mean of every `phase[j]` whose
/// frequency lies within `±width_octaves / 2` octaves of `freq[i]` (inclusive
/// bounds). If no point qualifies (e.g. `freq[i]` is negative, so its log2 is
/// NaN), `phase[i]` is passed through unchanged. Only the first `phase.len()`
/// entries of `freq` are used.
///
/// The frequency axis is not assumed to be sorted, so this is an O(n²) scan.
///
/// # Panics
///
/// Panics if `freq` is shorter than `phase`.
fn smooth_phase_log_freq(
    phase: &Array1<f64>,
    freq: &Array1<f64>,
    width_octaves: f64,
) -> Array1<f64> {
    let len = phase.len();
    let half_width = width_octaves / 2.0;
    // Compute each log2 once, rather than once per (i, j) pair inside the
    // quadratic scan.
    let log_freqs: Vec<f64> = (0..len).map(|j| freq[j].log2()).collect();

    let smoothed: Vec<f64> = log_freqs
        .iter()
        .zip(phase.iter())
        .map(|(&center_log, &own_phase)| {
            let low_log = center_log - half_width;
            let high_log = center_log + half_width;
            let mut sum = 0.0;
            let mut count = 0.0;
            // Summation runs in ascending j, matching the per-point mean order exactly.
            for (&f_log, &p) in log_freqs.iter().zip(phase.iter()) {
                if f_log >= low_log && f_log <= high_log {
                    sum += p;
                    count += 1.0;
                }
            }
            if count > 0.0 {
                sum / count
            } else {
                own_phase
            }
        })
        .collect();
    Array1::from_vec(smoothed)
}
// Unit tests for phase decomposition, phase interpolation/smoothing and the
// phase-only FIR designers.
#[cfg(test)]
mod tests {
use super::*;
// Builds a `Curve` with phase data from plain vectors.
fn make_curve_with_phase(freq: Vec<f64>, spl: Vec<f64>, phase: Vec<f64>) -> Curve {
Curve {
freq: Array1::from_vec(freq),
spl: Array1::from_vec(spl),
phase: Some(Array1::from_vec(phase)),
}
}
// A curve without phase must be rejected with an error, not a panic.
#[test]
fn test_decompose_phase_requires_phase_data() {
let curve = Curve {
freq: Array1::from_vec(vec![100.0, 1000.0]),
spl: Array1::from_vec(vec![80.0, 80.0]),
phase: None,
};
let config = MixedPhaseConfig::default();
let result = decompose_phase(&curve, &config);
assert!(result.is_err());
}
// Flat magnitude + zero phase on a log-spaced grid: the estimated delay and
// the residual should both stay small.
#[test]
fn test_decompose_phase_flat_measurement() {
let n = 64;
let freq: Vec<f64> = (0..n)
.map(|i| 20.0 * (20000.0 / 20.0_f64).powf(i as f64 / (n - 1) as f64))
.collect();
let spl = vec![80.0; n];
let phase = vec![0.0; n];
let curve = make_curve_with_phase(freq, spl, phase);
let config = MixedPhaseConfig::default();
let result = decompose_phase(&curve, &config);
assert!(result.is_ok());
let (min_phase, _excess, delay_ms, residual) = result.unwrap();
assert!(min_phase.len() == n);
assert!(
delay_ms.abs() < 5.0,
"delay should be small for flat response, got {:.2} ms",
delay_ms
);
let max_residual = residual.iter().map(|r| r.abs()).fold(0.0_f64, f64::max);
assert!(max_residual < 180.0, "residual should be bounded");
}
// Smoke test: a constant residual yields a non-trivial FIR of at least 31 taps.
#[test]
fn test_generate_excess_phase_fir_produces_valid_output() {
let n = 32;
let freq = Array1::linspace(20.0, 20000.0, n);
let residual_phase = Array1::from_elem(n, 5.0);
let config = MixedPhaseConfig::default();
let fir = generate_excess_phase_fir(&freq, &residual_phase, &config, 48000.0);
assert!(!fir.is_empty(), "FIR should not be empty");
assert!(fir.len() >= 31, "FIR should have minimum length");
assert!(
fir.iter().any(|&x| x.abs() > 1e-10),
"FIR should have non-zero taps"
);
}
// Interpolation reproduces the source nodes exactly. 316 Hz is ~ the
// geometric mean of 100 Hz and 1 kHz, so log-space interpolation lands at
// about the midpoint phase.
#[test]
fn test_interpolate_phase_log_space() {
let src_freqs = vec![100.0, 1000.0, 10000.0];
let src_phase = vec![0.0, -45.0, -90.0];
let result = interpolate_phase_log_space(&src_freqs, &src_phase, &src_freqs);
assert!((result[0] - 0.0).abs() < 0.1);
assert!((result[1] - (-45.0)).abs() < 0.1);
assert!((result[2] - (-90.0)).abs() < 0.1);
let mid = interpolate_phase_log_space(&src_freqs, &src_phase, &[316.0]);
assert!(
(mid[0] - (-22.5)).abs() < 1.0,
"expected ~-22.5, got {:.1}",
mid[0]
);
}
// The designed FIR is evaluated directly as a DTFT at 50 log-spaced
// frequencies; its magnitude must stay within 0.5 dB of unity.
#[test]
fn test_phase_only_fir_near_unity_magnitude() {
use num_complex::Complex64;
let freqs = vec![20.0, 100.0, 1000.0, 10000.0, 20000.0];
let magnitude_db = vec![0.0; 5];
let phase_deg = vec![0.0, -10.0, -30.0, -20.0, -5.0];
let config = FirDesignConfig {
n_taps: 511,
sample_rate: 48000.0,
pre_ringing: None,
..Default::default()
};
let fir = generate_phase_only_fir(&freqs, &magnitude_db, &phase_deg, &config);
assert_eq!(fir.len(), 511);
let test_freqs: Vec<f64> = (0..50)
.map(|i| 20.0 * (20000.0 / 20.0_f64).powf(i as f64 / 49.0))
.collect();
let sr = 48000.0;
let mut max_deviation_db: f64 = 0.0;
for &f in &test_freqs {
let w = 2.0 * std::f64::consts::PI * f / sr;
let mut h = Complex64::new(0.0, 0.0);
for (n, &val) in fir.iter().enumerate() {
let angle = -w * n as f64;
h += Complex64::from_polar(val, angle);
}
let mag_db = 20.0 * h.norm().log10();
max_deviation_db = max_deviation_db.max(mag_db.abs());
}
assert!(
max_deviation_db < 0.5,
"magnitude deviation should be < 0.5 dB, got {:.2} dB",
max_deviation_db,
);
}
// Zero target phase: the designer centres the impulse at n_taps / 2, so the
// centre tap should clearly dominate every other tap.
#[test]
fn test_phase_only_fir_zero_phase_is_near_impulse() {
let freqs = vec![20.0, 100.0, 1000.0, 10000.0, 20000.0];
let magnitude_db = vec![0.0; 5];
let phase_deg = vec![0.0; 5];
let config = FirDesignConfig {
n_taps: 255,
sample_rate: 48000.0,
pre_ringing: None,
..Default::default()
};
let fir = generate_phase_only_fir(&freqs, &magnitude_db, &phase_deg, &config);
assert_eq!(fir.len(), 255);
let center = 255 / 2;
let center_energy = fir[center].abs();
let off_center_max = fir
.iter()
.enumerate()
.filter(|&(i, _)| i != center)
.map(|(_, v)| v.abs())
.fold(0.0_f64, f64::max);
assert!(
center_energy > off_center_max * 2.0,
"center tap ({:.4}) should dominate off-center max ({:.4})",
center_energy,
off_center_max
);
}
// Every tap must be finite (no NaN/inf from the FFT or re-normalisation).
#[test]
fn test_phase_only_fir_produces_real_output() {
let freqs = vec![20.0, 100.0, 1000.0, 10000.0, 20000.0];
let phase_deg = vec![0.0, -30.0, -60.0, -30.0, 0.0];
let config = FirDesignConfig {
n_taps: 127,
sample_rate: 48000.0,
pre_ringing: None,
..Default::default()
};
let fir = generate_phase_only_fir(&freqs, &[0.0; 5], &phase_deg, &config);
for (i, &v) in fir.iter().enumerate() {
assert!(v.is_finite(), "tap {} should be finite, got {}", i, v);
}
}
// A pure 2 ms delay encoded as linear phase (-360 * f * t degrees). The test
// only checks that the recovered delay is positive and below 3x the true value.
#[test]
fn test_decompose_phase_with_delay() {
let n = 128;
let freq: Vec<f64> = (0..n)
.map(|i| 20.0 * (20000.0 / 20.0_f64).powf(i as f64 / (n - 1) as f64))
.collect();
let spl = vec![80.0; n];
let delay_ms = 2.0;
let delay_s = delay_ms / 1000.0;
let phase: Vec<f64> = freq.iter().map(|&f| -360.0 * f * delay_s).collect();
let curve = make_curve_with_phase(freq, spl, phase);
let config = MixedPhaseConfig {
phase_smoothing_octaves: 0.0,
..Default::default()
};
let result = decompose_phase(&curve, &config);
assert!(result.is_ok());
let (_, _, estimated_delay, _) = result.unwrap();
assert!(
estimated_delay > 0.0 && estimated_delay < delay_ms * 3.0,
"should recover positive delay roughly near {:.1} ms, got {:.2} ms",
delay_ms,
estimated_delay
);
}
// A depth mask of the wrong length (half the residual length) must trip the
// length assertion.
#[test]
#[should_panic(expected = "correction_depth length")]
fn test_depth_mask_length_mismatch_panics() {
let n = 32;
let freq = Array1::linspace(20.0, 20000.0, n);
let residual_phase = Array1::from_elem(n, 5.0);
let bad_depth = Array1::from_elem(n / 2, 0.8); let config = MixedPhaseConfig::default();
generate_excess_phase_fir_with_depth(
&freq,
&residual_phase,
&config,
48000.0,
Some(&bad_depth),
);
}
// With every depth below min_spatial_depth, no correction is applied. The
// masked FIR should therefore be more impulse-like (energy concentrated at the
// centre tap) than the unmasked one.
#[test]
fn test_depth_mask_zeros_low_depth_frequencies() {
let n = 32;
let freq = Array1::linspace(20.0, 20000.0, n);
let residual_phase = Array1::from_elem(n, 30.0);
let low_depth = Array1::from_elem(n, 0.1);
let config = MixedPhaseConfig {
min_spatial_depth: 0.5,
..Default::default()
};
let fir_masked = generate_excess_phase_fir_with_depth(
&freq,
&residual_phase,
&config,
48000.0,
Some(&low_depth),
);
let fir_unmasked =
generate_excess_phase_fir_with_depth(&freq, &residual_phase, &config, 48000.0, None);
let center = fir_masked.len() / 2;
let masked_center_ratio =
fir_masked[center].abs() / fir_masked.iter().map(|x| x.abs()).sum::<f64>().max(1e-12);
let unmasked_center_ratio = fir_unmasked[center].abs()
/ fir_unmasked.iter().map(|x| x.abs()).sum::<f64>().max(1e-12);
assert!(
masked_center_ratio > unmasked_center_ratio,
"masked FIR center ratio ({:.4}) should be more concentrated than unmasked ({:.4})",
masked_center_ratio,
unmasked_center_ratio,
);
}
}