use crate::roomeq::types::{MixedModeConfig, ProcessingMode};
use math_audio_iir_fir::{Biquad, BiquadFilterType};
use math_audio_optimisation::{DEConfigBuilder, differential_evolution};
use ndarray::Array1;
use num_complex::Complex64;
use std::f64::consts::PI;
/// Tuning knobs for the group-delay (GD) alignment optimiser.
#[derive(Debug, Clone)]
pub struct GdOptConfig {
    /// Sample rate handed to every all-pass `Biquad` that is built.
    pub sample_rate: f64,
    /// Upper bound of the per-channel delay search range `[0, max_delay_ms]`.
    pub max_delay_ms: f64,
    /// Number of all-pass sections optimised per channel.
    pub ap_per_channel: usize,
    /// Lower bound of the all-pass frequency search range.
    pub ap_min_freq: f64,
    /// Upper bound of the all-pass frequency search range.
    pub ap_max_freq: f64,
    /// Lower bound of the all-pass Q search range.
    pub ap_min_q: f64,
    /// Upper bound of the all-pass Q search range.
    pub ap_max_q: f64,
    /// When true, each channel gets an extra `[0, 1]` gene; values above 0.5
    /// invert that channel's polarity.
    pub optimize_polarity: bool,
    /// Maximum number of differential-evolution iterations.
    pub max_iter: usize,
    /// Differential-evolution population size setting.
    pub popsize: usize,
    /// Convergence tolerance passed to the differential-evolution builder.
    pub tol: f64,
    /// Optional RNG seed for reproducible optimisation runs.
    pub seed: Option<u64>,
}

impl Default for GdOptConfig {
    /// 48 kHz, up to 20 ms of delay, two all-pass sections per channel
    /// searched over 20-300 Hz and Q 0.3-10, polarity search enabled.
    fn default() -> Self {
        let defaults = GdOptConfig {
            seed: None,
            tol: 1e-8,
            popsize: 20,
            max_iter: 2000,
            optimize_polarity: true,
            ap_max_q: 10.0,
            ap_min_q: 0.3,
            ap_max_freq: 300.0,
            ap_min_freq: 20.0,
            ap_per_channel: 2,
            max_delay_ms: 20.0,
            sample_rate: 48000.0,
        };
        defaults
    }
}
/// Optimised processing for one channel.
#[derive(Debug, Clone)]
pub struct ChannelGdResult {
    /// Delay (ms) added to the channel, applied as a linear phase `exp(-j*w*tau)`.
    pub delay_ms: f64,
    /// Whether the channel's polarity is inverted.
    pub polarity_inverted: bool,
    /// All-pass sections applied to the channel.
    pub ap_filters: Vec<Biquad>,
    /// GD RMS deviation (ms) of this channel on its own, before processing.
    pub channel_gd_pre_rms_ms: f64,
    /// GD RMS deviation (ms) of this channel on its own, with its optimised
    /// delay / all-pass / polarity applied.
    pub channel_gd_post_rms_ms: f64,
}
/// Outcome of a group-delay optimisation over a frequency band.
#[derive(Debug, Clone)]
pub struct GroupDelayOptResult {
    /// `(low, high)` band edges; bins with `low <= f <= high` are optimised.
    pub band: (f64, f64),
    /// Per-channel processing, in the same order as the input channels.
    pub per_channel: Vec<ChannelGdResult>,
    /// Coherence-weighted RMS deviation (ms) of the summed response's group
    /// delay from its median, with baseline (unprocessed) parameters.
    pub sum_gd_pre_rms_ms: f64,
    /// Same metric as `sum_gd_pre_rms_ms`, with the optimised parameters.
    pub sum_gd_post_rms_ms: f64,
    /// Mean coherence over all channels and in-band bins.
    pub mean_coherence: f64,
    /// `20*log10(pre/post)`; 0 when `pre` is ~0 and 120 when `post` is ~0.
    pub improvement_db: f64,
}
/// One channel's measured transfer function, sampled on a frequency grid.
///
/// All four arrays are indexed by the same bin index and are expected to have
/// equal length.
#[derive(Debug, Clone)]
pub struct ChannelMeasurementInput {
    /// Bin frequencies in Hz (used as `2*pi*f` together with delays in seconds).
    pub freq: Array1<f64>,
    /// Magnitude in dB, converted to linear with `10^(spl/20)`.
    pub spl: Array1<f64>,
    /// Phase in radians.
    pub phase: Array1<f64>,
    /// Per-bin coherence, used as a weight. Assumed to be in [0, 1]; this is
    /// not checked here.
    pub coherence: Array1<f64>,
}
/// Reference data for aligning channels on a non-IIR (FIR) path, derived from
/// a [`GroupDelayOptResult`] by `build_gd_alignment_target`.
#[derive(Debug, Clone)]
pub struct GdAlignmentTarget {
    /// Optimised delay (ms) of each channel.
    pub per_channel_delay_ms: Vec<f64>,
    /// Group delay (ms) of the processed summed response at each in-band bin;
    /// one entry per element of `freq`.
    pub sum_gd_reference_ms: Vec<f64>,
    /// In-band bin frequencies matching `sum_gd_reference_ms`.
    pub freq: Array1<f64>,
}
pub fn build_gd_alignment_target(
channels: &[ChannelMeasurementInput],
result: &GroupDelayOptResult,
config: &GdOptConfig,
) -> GdAlignmentTarget {
let n_freq = channels[0].freq.len();
let band_indices: Vec<usize> = (0..n_freq)
.filter(|&i| channels[0].freq[i] >= result.band.0 && channels[0].freq[i] <= result.band.1)
.collect();
let params = encode_result_as_params(result, config);
let sum_gd = compute_sum_gd(channels, ¶ms, &band_indices, config);
let per_channel_delay_ms = result.per_channel.iter().map(|ch| ch.delay_ms).collect();
let freq = Array1::from_iter(band_indices.iter().map(|&i| channels[0].freq[i]));
GdAlignmentTarget {
per_channel_delay_ms,
sum_gd_reference_ms: sum_gd,
freq,
}
}
/// Outcome classification of a GD-Opt attempt, rendered into a summary string
/// by `GroupDelayOptSummary::from_advisory`.
///
/// NOTE(review): no code in this section constructs these variants; the
/// meanings below are inferred from the variant names. Confirm against the
/// callers that emit them.
#[derive(Debug, Clone, PartialEq)]
pub enum GdOptAdvisory {
    /// Optimisation succeeded with the given improvement (dB).
    Success { improvement_db: f64 },
    /// Presumably: the measurements carried no usable phase.
    NoPhaseData,
    /// Presumably: in-band coherence was too low to trust the phase.
    CoherenceBelowThreshold { mean_coherence: f64 },
    /// Presumably: PhaseLinear mode with no FIR alignment target available.
    PhaseLinearNoTarget,
    /// Presumably: fewer than two channels were supplied.
    InsufficientChannels,
    /// Presumably: no frequency bins fell inside the band.
    EmptyBand,
    /// Presumably: optimisation ran but gained too little to apply.
    MinimalImprovement { improvement_db: f64 },
}
// Serialisable summary of a GD-Opt run (or of why it was skipped).
// Plain `//` comments are used deliberately: `///` doc comments would be
// picked up by schemars as schema descriptions and change the JSON schema.
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize, schemars::JsonSchema)]
pub struct GroupDelayOptSummary {
    // `(low, high)` band edges; `(0.0, 0.0)` for advisory-only summaries.
    pub band: (f64, f64),
    // Optional channel labels; omitted from the serialised form when empty.
    #[serde(default, skip_serializing_if = "Vec::is_empty")]
    pub channel_names: Vec<String>,
    pub per_channel_delay_ms: Vec<f64>,
    pub per_channel_polarity_inverted: Vec<bool>,
    // Number of all-pass sections per channel.
    pub per_channel_ap_count: Vec<usize>,
    pub sum_gd_pre_rms_ms: f64,
    pub sum_gd_post_rms_ms: f64,
    pub mean_coherence: f64,
    pub improvement_db: f64,
    // "success" or a machine-readable reason such as "empty_band".
    pub advisory: String,
}
impl GroupDelayOptSummary {
    /// Builds a summary from a successful run.
    ///
    /// `names` is stored verbatim in `channel_names` and is not checked
    /// against the channel count. The advisory is always the plain string
    /// `"success"`.
    pub fn from_result_with_names(result: &GroupDelayOptResult, names: Vec<String>) -> Self {
        let n_ch = result.per_channel.len();
        let mut per_channel_delay_ms = Vec::with_capacity(n_ch);
        let mut per_channel_polarity_inverted = Vec::with_capacity(n_ch);
        let mut per_channel_ap_count = Vec::with_capacity(n_ch);
        for ch in &result.per_channel {
            per_channel_delay_ms.push(ch.delay_ms);
            per_channel_polarity_inverted.push(ch.polarity_inverted);
            per_channel_ap_count.push(ch.ap_filters.len());
        }
        Self {
            band: result.band,
            channel_names: names,
            per_channel_delay_ms,
            per_channel_polarity_inverted,
            per_channel_ap_count,
            sum_gd_pre_rms_ms: result.sum_gd_pre_rms_ms,
            sum_gd_post_rms_ms: result.sum_gd_post_rms_ms,
            mean_coherence: result.mean_coherence,
            improvement_db: result.improvement_db,
            advisory: String::from("success"),
        }
    }

    /// Builds an otherwise-empty summary whose `advisory` string encodes
    /// `advisory`. All numeric fields are zero and all vectors are empty.
    pub fn from_advisory(advisory: &GdOptAdvisory) -> Self {
        let reason: String = match *advisory {
            GdOptAdvisory::Success { improvement_db } => format!("success:{improvement_db:.1}dB"),
            GdOptAdvisory::NoPhaseData => String::from("no_phase_data"),
            GdOptAdvisory::CoherenceBelowThreshold { mean_coherence } => {
                format!("coherence_below_threshold:{mean_coherence:.2}")
            }
            GdOptAdvisory::PhaseLinearNoTarget => String::from("phase_linear_no_target"),
            GdOptAdvisory::InsufficientChannels => String::from("insufficient_channels"),
            GdOptAdvisory::EmptyBand => String::from("empty_band"),
            GdOptAdvisory::MinimalImprovement { improvement_db } => {
                format!("minimal_improvement:{improvement_db:.1}dB")
            }
        };
        Self {
            band: (0.0, 0.0),
            channel_names: Vec::new(),
            per_channel_delay_ms: Vec::new(),
            per_channel_polarity_inverted: Vec::new(),
            per_channel_ap_count: Vec::new(),
            sum_gd_pre_rms_ms: 0.0,
            sum_gd_post_rms_ms: 0.0,
            mean_coherence: 0.0,
            improvement_db: 0.0,
            advisory: reason,
        }
    }
}
/// Optimisation band around a crossover.
///
/// The lower edge is two octaves below the crossover (`crossover / 4`), but
/// never below `min_freq`. The upper edge is one octave above the crossover.
/// The function does not check that the resulting `low < high`.
pub fn derive_band(min_freq: f64, crossover_freq: f64) -> (f64, f64) {
    let two_octaves_below = 0.25 * crossover_freq;
    let lower = f64::max(min_freq, two_octaves_below);
    let upper = 2.0 * crossover_freq;
    (lower, upper)
}
pub fn optimize_group_delay(
channels: &[ChannelMeasurementInput],
band: (f64, f64),
config: &GdOptConfig,
) -> Result<GroupDelayOptResult, String> {
let n_ch = channels.len();
if n_ch < 2 {
return Err("GD-Opt requires at least 2 channels".into());
}
let n_freq = channels[0].freq.len();
for (i, ch) in channels.iter().enumerate() {
if ch.freq.len() != n_freq || ch.spl.len() != n_freq || ch.phase.len() != n_freq {
return Err(format!("Channel {} has inconsistent array lengths", i));
}
}
let band_indices: Vec<usize> = (0..n_freq)
.filter(|&i| channels[0].freq[i] >= band.0 && channels[0].freq[i] <= band.1)
.collect();
if band_indices.is_empty() {
return Err("No frequency bins within the specified band".into());
}
let mean_coherence = compute_mean_coherence(channels, &band_indices);
let identity_params = vec![0.0; param_count(n_ch, config)];
let sum_gd_pre_rms_ms =
compute_sum_gd_rms(channels, &identity_params, &band_indices, config);
let bounds = build_bounds(n_ch, config);
let de_config = {
let mut builder = DEConfigBuilder::new()
.maxiter(config.max_iter)
.popsize(config.popsize)
.tol(config.tol);
if let Some(seed) = config.seed {
builder = builder.seed(seed);
}
builder.build().map_err(|e| format!("DE config error: {e}"))?
};
let channels_ref = channels;
let band_indices_ref = &band_indices;
let config_ref = config;
let loss_fn = |x: &Array1<f64>| -> f64 {
gd_loss(channels_ref, x.as_slice().unwrap(), band_indices_ref, config_ref)
};
let report = differential_evolution(&loss_fn, &bounds, de_config)
.map_err(|e| format!("DE failed: {e}"))?;
let best_params = report.x.as_slice().unwrap();
let sum_gd_post_rms_ms =
compute_sum_gd_rms(channels, best_params, &band_indices, config);
let improvement_db = if sum_gd_pre_rms_ms < 1e-15 {
0.0 } else if sum_gd_post_rms_ms > 1e-15 {
20.0 * (sum_gd_pre_rms_ms / sum_gd_post_rms_ms).log10()
} else {
120.0 };
let per_channel = decode_per_channel(channels, best_params, &band_indices, config);
Ok(GroupDelayOptResult {
band,
per_channel,
sum_gd_pre_rms_ms,
sum_gd_post_rms_ms,
mean_coherence,
improvement_db,
})
}
/// Largest all-pass budget per channel tried by `optimize_group_delay_adaptive`.
const MAX_AP_BUDGET: usize = 2;
/// Required ratio of mean bootstrap improvement to its sample standard
/// deviation before a larger all-pass budget is accepted.
const BOOTSTRAP_SIGMA_THRESHOLD: f64 = 3.0;

/// Picks the all-pass budget per channel (0..=`MAX_AP_BUDGET`) by bootstrap
/// over several sweep realisations.
///
/// The search starts from a delay/polarity-only fit (0 all-pass sections).
/// Each larger budget is accepted only if both of these hold:
/// - its per-realisation RMS improvement over the currently accepted fit has
///   mean / sample-std above `BOOTSTRAP_SIGMA_THRESHOLD`;
/// - it lowers the nominal post-RMS.
///
/// The search stops at the first budget that fails. `config.ap_per_channel`
/// is ignored.
///
/// # Errors
/// Returns `Err` with fewer than two realisations, or on any error from the
/// underlying optimisation or bootstrap evaluation.
pub fn optimize_group_delay_adaptive(
    channels: &[ChannelMeasurementInput],
    sweep_realisations: &[Vec<ChannelMeasurementInput>],
    band: (f64, f64),
    config: &GdOptConfig,
) -> Result<GroupDelayOptResult, String> {
    if sweep_realisations.len() < 2 {
        return Err(
            "Adaptive AP bootstrap requires at least 2 sweep realisations (N >= 2)".into(),
        );
    }
    let with_budget = |ap_per_channel: usize| GdOptConfig {
        ap_per_channel,
        ..config.clone()
    };

    let mut accepted_config = with_budget(0);
    let mut accepted = optimize_group_delay(channels, band, &accepted_config)?;

    for budget in 1..=MAX_AP_BUDGET {
        let candidate_config = with_budget(budget);
        let candidate = optimize_group_delay(channels, band, &candidate_config)?;
        let deltas = compute_bootstrap_improvements(
            sweep_realisations,
            band,
            &accepted_config,
            &accepted,
            &candidate_config,
            &candidate,
        )?;

        let count = deltas.len() as f64;
        let mean = deltas.iter().sum::<f64>() / count;
        let spread = (deltas
            .iter()
            .map(|&d| (d - mean).powi(2))
            .sum::<f64>()
            / (count - 1.0))
            .sqrt();

        // The gain must be consistent across sweeps *and* must actually lower
        // the nominal RMS to justify the extra sections.
        let robust_gain = spread > 1e-15 && (mean / spread) > BOOTSTRAP_SIGMA_THRESHOLD;
        let lowers_rms = candidate.sum_gd_post_rms_ms < accepted.sum_gd_post_rms_ms;
        if !(robust_gain && lowers_rms) {
            break;
        }
        accepted = candidate;
        accepted_config = candidate_config;
    }
    Ok(accepted)
}
/// Evaluates baseline and trial solutions on each sweep realisation.
///
/// Returns, per realisation, `rms_baseline - rms_trial` (positive = the
/// trial is better). A realisation with no bins inside `band` contributes
/// `0.0`.
///
/// Both parameter vectors are independent of the realisation, so they are
/// encoded once up front.
///
/// # Errors
/// Returns `Err` when a realisation's channel count differs from the
/// baseline result's, or when its channels' array lengths are inconsistent.
/// Without that check these cases would cause an index panic further down.
fn compute_bootstrap_improvements(
    sweep_realisations: &[Vec<ChannelMeasurementInput>],
    band: (f64, f64),
    baseline_config: &GdOptConfig,
    baseline_result: &GroupDelayOptResult,
    trial_config: &GdOptConfig,
    trial_result: &GroupDelayOptResult,
) -> Result<Vec<f64>, String> {
    let baseline_params = encode_result_as_params(baseline_result, baseline_config);
    let trial_params = encode_result_as_params(trial_result, trial_config);
    let n_ch = baseline_result.per_channel.len();

    let mut improvements = Vec::with_capacity(sweep_realisations.len());
    for (r_idx, realisation) in sweep_realisations.iter().enumerate() {
        if realisation.len() != n_ch {
            return Err("Sweep realisation channel count mismatch".into());
        }
        let reference = match realisation.first() {
            Some(ch) => ch,
            None => {
                improvements.push(0.0);
                continue;
            }
        };
        let n_freq = reference.freq.len();
        let inconsistent = realisation.iter().any(|ch| {
            ch.freq.len() != n_freq
                || ch.spl.len() != n_freq
                || ch.phase.len() != n_freq
                || ch.coherence.len() != n_freq
        });
        if inconsistent {
            return Err(format!(
                "Sweep realisation {r_idx} has inconsistent array lengths"
            ));
        }
        let band_indices: Vec<usize> = (0..n_freq)
            .filter(|&i| reference.freq[i] >= band.0 && reference.freq[i] <= band.1)
            .collect();
        if band_indices.is_empty() {
            improvements.push(0.0);
            continue;
        }
        let rms_baseline =
            compute_sum_gd_rms(realisation, &baseline_params, &band_indices, baseline_config);
        let rms_trial = compute_sum_gd_rms(realisation, &trial_params, &band_indices, trial_config);
        improvements.push(rms_baseline - rms_trial);
    }
    Ok(improvements)
}
/// Inverse of `decode_channel_params`: flattens a result into the optimiser's
/// parameter layout. Per channel, the slots are:
/// `[delay_ms, (ap_freq, ap_q) x ap_per_channel, polarity?]`.
///
/// All-pass filters beyond `config.ap_per_channel` are dropped. Unused
/// all-pass slots stay `0.0`, which decodes to a freq = 0 / Q = 0 all-pass,
/// so callers should pass a config whose `ap_per_channel` matches the
/// result.
fn encode_result_as_params(result: &GroupDelayOptResult, config: &GdOptConfig) -> Vec<f64> {
    let n_ap = config.ap_per_channel;
    let stride = 1 + n_ap * 2 + if config.optimize_polarity { 1 } else { 0 };
    let mut params = vec![0.0; result.per_channel.len() * stride];
    for (slot, ch_result) in params.chunks_exact_mut(stride).zip(&result.per_channel) {
        slot[0] = ch_result.delay_ms;
        let ap_slots = &mut slot[1..1 + 2 * n_ap];
        for (pair, ap) in ap_slots.chunks_exact_mut(2).zip(&ch_result.ap_filters) {
            pair[0] = ap.freq;
            pair[1] = ap.q;
        }
        if config.optimize_polarity {
            slot[1 + 2 * n_ap] = if ch_result.polarity_inverted { 1.0 } else { 0.0 };
        }
    }
    params
}
/// Dispatches GD optimisation according to the processing mode.
///
/// - `LowLatency`, `WarpedIir`, `KautzModal`: plain `optimize_group_delay`.
/// - `Hybrid`: the band must stay at or below the mixed-mode crossover
///   (default 300 when no mixed config is given). The all-pass frequency
///   search range is also clamped to the crossover, so the filters cannot be
///   placed in the FIR band.
/// - `MixedPhase`: at most one all-pass section per channel.
/// - `PhaseLinear`: rejected; that mode uses the FIR path instead.
///
/// # Errors
/// Returns `Err` for:
/// - `PhaseLinear`;
/// - `Hybrid` when `band.1` exceeds the crossover;
/// - `Hybrid` when all-pass sections are requested but `ap_min_freq` lies
///   above the crossover;
/// - any error from `optimize_group_delay`.
pub fn optimize_group_delay_for_mode(
    channels: &[ChannelMeasurementInput],
    band: (f64, f64),
    config: &GdOptConfig,
    processing_mode: &ProcessingMode,
    mixed_mode_config: Option<&MixedModeConfig>,
) -> Result<GroupDelayOptResult, String> {
    match processing_mode {
        ProcessingMode::LowLatency | ProcessingMode::WarpedIir | ProcessingMode::KautzModal => {
            optimize_group_delay(channels, band, config)
        }
        ProcessingMode::Hybrid => {
            let xo_freq = mixed_mode_config
                .map(|m| m.crossover_freq)
                .unwrap_or(300.0);
            if band.1 > xo_freq {
                return Err(format!(
                    "Hybrid mode: GD-Opt band_hi ({:.1} Hz) exceeds mixed_config crossover \
                     ({:.1} Hz). AP filters must stay in the IIR band.",
                    band.1, xo_freq,
                ));
            }
            // The band check alone does not stop the optimiser from placing an
            // all-pass centre above the crossover when ap_max_freq > xo_freq.
            let hybrid_config = GdOptConfig {
                ap_max_freq: config.ap_max_freq.min(xo_freq),
                ..config.clone()
            };
            if hybrid_config.ap_per_channel > 0
                && hybrid_config.ap_min_freq > hybrid_config.ap_max_freq
            {
                return Err(format!(
                    "Hybrid mode: ap_min_freq ({:.1} Hz) is above the mixed_config crossover \
                     ({:.1} Hz); no room for AP filters in the IIR band.",
                    hybrid_config.ap_min_freq, xo_freq,
                ));
            }
            optimize_group_delay(channels, band, &hybrid_config)
        }
        ProcessingMode::MixedPhase => {
            let mixed_phase_config = GdOptConfig {
                ap_per_channel: config.ap_per_channel.min(1),
                ..config.clone()
            };
            optimize_group_delay(channels, band, &mixed_phase_config)
        }
        ProcessingMode::PhaseLinear => Err(
            "PhaseLinear mode does not use IIR AP filters. \
             Use the FIR path (GD-3b) with GdAlignmentTarget instead."
                .into(),
        ),
    }
}
/// Total length of the optimiser's parameter vector for `n_ch` channels.
/// Each channel has one delay, a (freq, Q) pair per all-pass section, and
/// one polarity gene when polarity search is on.
fn param_count(n_ch: usize, config: &GdOptConfig) -> usize {
    let polarity_slots = usize::from(config.optimize_polarity);
    let per_channel = 1 + 2 * config.ap_per_channel + polarity_slots;
    per_channel * n_ch
}
/// Search bounds for every parameter, in the same layout as
/// `decode_channel_params` expects. Per channel:
/// - delay in `[0, max_delay_ms]`;
/// - for each all-pass section, frequency then Q bounds;
/// - an optional `[0, 1]` polarity gene.
fn build_bounds(n_ch: usize, config: &GdOptConfig) -> Vec<(f64, f64)> {
    let ap_pair = [
        (config.ap_min_freq, config.ap_max_freq),
        (config.ap_min_q, config.ap_max_q),
    ];
    let polarity = config.optimize_polarity.then_some((0.0, 1.0));
    (0..n_ch)
        .flat_map(|_| {
            std::iter::once((0.0, config.max_delay_ms))
                .chain((0..config.ap_per_channel).flat_map(move |_| ap_pair))
                .chain(polarity)
        })
        .collect()
}
/// One channel's slice of the optimiser parameter vector, decoded.
struct ChannelParams {
    /// Added delay in milliseconds.
    delay_ms: f64,
    /// `(frequency, Q)` of each all-pass section, in parameter order.
    ap_filters: Vec<(f64, f64)>,
    /// True when the polarity gene is above 0.5 (always false without
    /// polarity search).
    polarity_inverted: bool,
}
/// Decodes channel `ch`'s block of `params`. The per-channel layout is
/// `[delay_ms, (freq, q) x ap_per_channel, polarity?]`.
///
/// Panics if `params` is too short for channel `ch`.
fn decode_channel_params(params: &[f64], ch: usize, config: &GdOptConfig) -> ChannelParams {
    let n_ap = config.ap_per_channel;
    let stride = 1 + n_ap * 2 + if config.optimize_polarity { 1 } else { 0 };
    let slot = &params[ch * stride..];
    let delay_ms = slot[0];
    let ap_filters: Vec<(f64, f64)> = slot[1..1 + 2 * n_ap]
        .chunks_exact(2)
        .map(|pair| (pair[0], pair[1]))
        .collect();
    // Short-circuit: the polarity slot is only read when it exists.
    let polarity_inverted = config.optimize_polarity && slot[1 + 2 * n_ap] > 0.5;
    ChannelParams {
        delay_ms,
        ap_filters,
        polarity_inverted,
    }
}
/// Complex response of one processed channel at bin `freq_idx`.
///
/// The measured response (dB magnitude, phase in radians) is multiplied by:
/// - the delay term `exp(-j*w*tau)`;
/// - each all-pass section's response;
///
/// and is finally negated if the polarity is inverted.
fn channel_complex_at(
    ch: &ChannelMeasurementInput,
    freq_idx: usize,
    ch_params: &ChannelParams,
    config: &GdOptConfig,
) -> Complex64 {
    let f = ch.freq[freq_idx];
    let linear_mag = 10.0_f64.powf(ch.spl[freq_idx] / 20.0);
    let measured = Complex64::from_polar(linear_mag, ch.phase[freq_idx]);
    let omega = 2.0 * PI * f;
    let delay_term = Complex64::from_polar(1.0, -omega * (ch_params.delay_ms * 1e-3));
    let shaped = ch_params
        .ap_filters
        .iter()
        .fold(measured * delay_term, |acc, &(ap_freq, ap_q)| {
            let section =
                Biquad::new(BiquadFilterType::AllPass, ap_freq, config.sample_rate, ap_q, 0.0);
            acc * section.complex_response(f)
        });
    if ch_params.polarity_inverted {
        -shaped
    } else {
        shaped
    }
}
fn compute_sum_gd(
channels: &[ChannelMeasurementInput],
params: &[f64],
band_indices: &[usize],
config: &GdOptConfig,
) -> Vec<f64> {
let ch_params: Vec<ChannelParams> = (0..channels.len())
.map(|ch_idx| decode_channel_params(params, ch_idx, config))
.collect();
let mut gd_ms = Vec::with_capacity(band_indices.len());
for (bi, &idx) in band_indices.iter().enumerate() {
let idx_next = if bi + 1 < band_indices.len() {
band_indices[bi + 1]
} else if idx + 1 < channels[0].freq.len() {
idx + 1
} else {
if !gd_ms.is_empty() {
gd_ms.push(*gd_ms.last().unwrap());
} else {
gd_ms.push(0.0);
}
continue;
};
let f0 = channels[0].freq[idx];
let f1 = channels[0].freq[idx_next];
let omega0 = 2.0 * PI * f0;
let omega1 = 2.0 * PI * f1;
let mut sum0 = Complex64::new(0.0, 0.0);
let mut sum1 = Complex64::new(0.0, 0.0);
for (ch, cp) in channels.iter().zip(ch_params.iter()) {
sum0 += channel_complex_at(ch, idx, cp, config);
sum1 += channel_complex_at(ch, idx_next, cp, config);
}
let phase0 = sum0.arg();
let phase1 = sum1.arg();
let d_phase = unwrap_phase_diff(phase1 - phase0);
let d_omega = omega1 - omega0;
let gd_s = if d_omega.abs() > 1e-15 {
-d_phase / d_omega
} else {
0.0
};
gd_ms.push(gd_s * 1000.0);
}
gd_ms
}
/// Coherence-weighted RMS deviation (ms) of the summed group delay from its
/// median, over `band_indices`.
///
/// The weight of each bin is the square of the channel-averaged coherence
/// there. Returns 0.0 when there are no bins or the total weight is not
/// positive.
fn compute_sum_gd_rms(
    channels: &[ChannelMeasurementInput],
    params: &[f64],
    band_indices: &[usize],
    config: &GdOptConfig,
) -> f64 {
    let gd = compute_sum_gd(channels, params, band_indices, config);
    if gd.is_empty() {
        return 0.0;
    }
    // Reference level: the element of rank len/2 (upper median for even
    // lengths). `total_cmp` is a total order, so the selection is well
    // defined even if a NaN appears. `partial_cmp(..).unwrap_or(Equal)` is
    // not a total order in that case, and `sort_by` may panic on it.
    let mut scratch = gd.clone();
    let mid = scratch.len() / 2;
    let target = *scratch.select_nth_unstable_by(mid, f64::total_cmp).1;

    let n_ch = channels.len() as f64;
    let (weighted_sum, weight_total) = band_indices.iter().zip(&gd).fold(
        (0.0, 0.0),
        |(sum_acc, weight_acc), (&idx, &g)| {
            let mean_coh = channels.iter().map(|ch| ch.coherence[idx]).sum::<f64>() / n_ch;
            let w = mean_coh * mean_coh;
            let dev = g - target;
            (sum_acc + w * dev * dev, weight_acc + w)
        },
    );
    if weight_total > 0.0 {
        (weighted_sum / weight_total).sqrt()
    } else {
        0.0
    }
}
/// Differential-evolution cost: the weighted summed-GD RMS (ms).
///
/// Non-finite costs are reported as `+inf`. A NaN compares false against
/// every value, so a NaN candidate could never be ranked sensibly by the
/// optimiser; `+inf` makes it simply the worst candidate.
fn gd_loss(
    channels: &[ChannelMeasurementInput],
    params: &[f64],
    band_indices: &[usize],
    config: &GdOptConfig,
) -> f64 {
    let rms = compute_sum_gd_rms(channels, params, band_indices, config);
    if rms.is_finite() {
        rms
    } else {
        f64::INFINITY
    }
}
/// Mean coherence over every (channel, in-band bin) pair.
/// Returns 0.0 when there are no channels or no band bins.
fn compute_mean_coherence(channels: &[ChannelMeasurementInput], band_indices: &[usize]) -> f64 {
    let samples = channels.len() * band_indices.len();
    if samples == 0 {
        return 0.0;
    }
    let total = channels
        .iter()
        .flat_map(|ch| band_indices.iter().map(move |&idx| ch.coherence[idx]))
        .fold(0.0, |acc, c| acc + c);
    total / samples as f64
}
/// Wraps a phase difference (radians) into `[-PI, PI]`.
///
/// Inputs up to 8π in magnitude are folded by repeated ±2π steps. Larger
/// magnitudes are first reduced with `rem_euclid`. Once 2π drops below the
/// f64 spacing around `d`, `d - 2π == d`, and the step loop alone would
/// never terminate. Non-finite input has no meaningful wrapped value and
/// yields NaN; an infinite input would otherwise loop forever.
fn unwrap_phase_diff(d: f64) -> f64 {
    if !d.is_finite() {
        return f64::NAN;
    }
    let two_pi = 2.0 * PI;
    let mut d = if d.abs() > 8.0 * PI {
        (d + PI).rem_euclid(two_pi) - PI
    } else {
        d
    };
    while d > PI {
        d -= two_pi;
    }
    while d < -PI {
        d += two_pi;
    }
    d
}
fn decode_per_channel(
channels: &[ChannelMeasurementInput],
params: &[f64],
band_indices: &[usize],
config: &GdOptConfig,
) -> Vec<ChannelGdResult> {
let n_ch = channels.len();
let mut results = Vec::with_capacity(n_ch);
for ch_idx in 0..n_ch {
let cp = decode_channel_params(params, ch_idx, config);
let ap_filters: Vec<Biquad> = cp
.ap_filters
.iter()
.map(|&(freq, q)| {
Biquad::new(BiquadFilterType::AllPass, freq, config.sample_rate, q, 0.0)
})
.collect();
let single_ch = &channels[ch_idx..ch_idx + 1];
let id_1ch = vec![0.0; param_count(1, config)];
let pre_rms = compute_sum_gd_rms(single_ch, &id_1ch, band_indices, config);
let per_ch_size =
1 + config.ap_per_channel * 2 + if config.optimize_polarity { 1 } else { 0 };
let ch_offset = ch_idx * per_ch_size;
let post_params_1ch = params[ch_offset..ch_offset + per_ch_size].to_vec();
let post_rms = compute_sum_gd_rms(single_ch, &post_params_1ch, band_indices, config);
results.push(ChannelGdResult {
delay_ms: cp.delay_ms,
polarity_inverted: cp.polarity_inverted,
ap_filters,
channel_gd_pre_rms_ms: pre_rms,
channel_gd_post_rms_ms: post_rms,
});
}
results
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Synthetic pure-delay measurement: 0 dB magnitude on every bin, linear
    /// phase `-2*pi*f*delay`, and a constant `coherence`.
    fn make_delayed_channel(
        freq_grid: &Array1<f64>,
        delay_ms: f64,
        coherence: f64,
    ) -> ChannelMeasurementInput {
        let n = freq_grid.len();
        // 0 dB everywhere -> unit magnitude after the dB-to-linear conversion.
        let spl = Array1::zeros(n);
        let delay_s = delay_ms * 1e-3;
        let phase = freq_grid.mapv(|f| -2.0 * PI * f * delay_s);
        let coherence = Array1::from_elem(n, coherence);
        ChannelMeasurementInput {
            freq: freq_grid.clone(),
            spl,
            phase,
            coherence,
        }
    }

    /// `n_points` logarithmically spaced frequencies from `f_min` to `f_max`
    /// (both included). Callers pass `n_points >= 2`; with `n_points == 1`
    /// the single value is NaN (0 / 0 step).
    fn log_freq_grid(f_min: f64, f_max: f64, n_points: usize) -> Array1<f64> {
        let log_min = f_min.ln();
        let log_max = f_max.ln();
        Array1::from_iter((0..n_points).map(|i| {
            let t = i as f64 / (n_points - 1) as f64;
            (log_min + t * (log_max - log_min)).exp()
        }))
    }

    /// Lower edge = max(min_freq, crossover / 4); upper edge = 2 x crossover.
    #[test]
    fn test_derive_band() {
        let (lo, hi) = derive_band(20.0, 80.0);
        assert!((lo - 20.0).abs() < 1e-10);
        assert!((hi - 160.0).abs() < 1e-10);
        let (lo2, hi2) = derive_band(30.0, 80.0);
        assert!((lo2 - 30.0).abs() < 1e-10);
        assert!((hi2 - 160.0).abs() < 1e-10);
    }

    /// Two pure delays (2 ms and 4 ms), delay-only search: the effective
    /// delays `measured + optimised` must agree within 0.1 ms, and the summed
    /// GD RMS must improve by at least 6 dB.
    #[test]
    fn test_two_channel_delay_recovery() {
        let freq = log_freq_grid(20.0, 5000.0, 500);
        let ch0 = make_delayed_channel(&freq, 2.0, 0.98);
        let ch1 = make_delayed_channel(&freq, 4.0, 0.98);
        let channels = vec![ch0, ch1];
        let band = (20.0, 5000.0);
        let config = GdOptConfig {
            sample_rate: 48000.0,
            max_delay_ms: 10.0,
            // Delay only: no all-pass sections, no polarity gene.
            ap_per_channel: 0,
            optimize_polarity: false,
            max_iter: 5000,
            popsize: 30,
            tol: 1e-12,
            seed: Some(42),
            ..Default::default()
        };
        let result = optimize_group_delay(&channels, band, &config).unwrap();
        let d0 = result.per_channel[0].delay_ms;
        let d1 = result.per_channel[1].delay_ms;
        let effective_delay_0 = 2.0 + d0;
        let effective_delay_1 = 4.0 + d1;
        let residual_diff = (effective_delay_0 - effective_delay_1).abs();
        assert!(
            residual_diff < 0.1,
            "Delay recovery failed: residual difference = {:.3} ms (expected < 0.1 ms). \
             d0={:.3}, d1={:.3}, effective: {:.3} vs {:.3}",
            residual_diff,
            d0,
            d1,
            effective_delay_0,
            effective_delay_1,
        );
        assert!(
            result.improvement_db >= 6.0,
            "Improvement too low: {:.1} dB (expected >= 6.0 dB). \
             pre_rms={:.3} ms, post_rms={:.3} ms",
            result.improvement_db,
            result.sum_gd_pre_rms_ms,
            result.sum_gd_post_rms_ms,
        );
    }

    /// The lower band edge is min_freq when it exceeds crossover / 4,
    /// otherwise crossover / 4 (25 Hz for a 100 Hz crossover).
    #[test]
    fn test_band_derivation_respects_min_freq() {
        let (lo, _) = derive_band(50.0, 100.0);
        assert!((lo - 50.0).abs() < 1e-10);
        let (lo2, _) = derive_band(10.0, 100.0);
        assert!((lo2 - 25.0).abs() < 1e-10);
    }

    /// A 0 ms / 10 ms pair where the delayed channel has coherence 0.1 on the
    /// lower half of the bins. The test expects that pair's weighted RMS to
    /// be below the same pair with uniform 0.95 coherence.
    #[test]
    fn test_coherence_weighting() {
        let freq = log_freq_grid(20.0, 300.0, 100);
        let ch0 = make_delayed_channel(&freq, 0.0, 0.95);
        let n = freq.len();
        let spl = Array1::zeros(n);
        let delay_s = 10.0e-3;
        let phase = freq.mapv(|f| -2.0 * PI * f * delay_s);
        let mut coherence = Array1::from_elem(n, 0.95);
        for i in 0..n / 2 {
            coherence[i] = 0.1;
        }
        let ch1 = ChannelMeasurementInput {
            freq: freq.clone(),
            spl,
            phase,
            coherence,
        };
        let channels = vec![ch0, ch1];
        let band_indices: Vec<usize> = (0..n).collect();
        // NOTE(review): the default config has ap_per_channel = 2, so this
        // all-zero vector decodes into two all-pass slots with freq = 0 and
        // Q = 0. Confirm Biquad::new yields a finite response for that.
        let identity = vec![0.0; param_count(2, &GdOptConfig::default())];
        let rms = compute_sum_gd_rms(
            &channels,
            &identity,
            &band_indices,
            &GdOptConfig::default(),
        );
        assert!(rms > 0.0, "RMS should be non-zero with delay mismatch");
        let ch1_high_coh = make_delayed_channel(&freq, 10.0, 0.95);
        let channels_high_coh = vec![make_delayed_channel(&freq, 0.0, 0.95), ch1_high_coh];
        let rms_high = compute_sum_gd_rms(
            &channels_high_coh,
            &identity,
            &band_indices,
            &GdOptConfig::default(),
        );
        assert!(
            rms_high > rms,
            "High-coherence RMS ({:.3}) should exceed low-coherence RMS ({:.3})",
            rms_high,
            rms,
        );
    }

    /// A single channel is rejected with an error.
    #[test]
    fn test_minimum_channels() {
        let freq = log_freq_grid(20.0, 300.0, 50);
        let ch0 = make_delayed_channel(&freq, 0.0, 0.95);
        let result = optimize_group_delay(&[ch0], (20.0, 300.0), &GdOptConfig::default());
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("at least 2 channels"));
    }

    /// Adaptive search on pure delays with small per-sweep delay jitter.
    /// Only delay alignment (< 0.2 ms) and a >= 6 dB improvement are
    /// asserted; the accepted all-pass count itself is not checked.
    #[test]
    fn test_adaptive_bootstrap_rejects_noisy_ap() {
        let freq = log_freq_grid(20.0, 5000.0, 300);
        let ch0 = make_delayed_channel(&freq, 2.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 4.0, 0.95);
        let channels = vec![ch0, ch1];
        let sweep_realisations: Vec<Vec<ChannelMeasurementInput>> = (0..4)
            .map(|seed| {
                // Jitter in ms (0.05 to 0.35 microseconds).
                let jitter = (seed as f64 * 0.1 + 0.05) * 1e-3;
                vec![
                    make_delayed_channel(&freq, 2.0 + jitter, 0.95),
                    make_delayed_channel(&freq, 4.0 - jitter, 0.95),
                ]
            })
            .collect();
        let config = GdOptConfig {
            sample_rate: 48000.0,
            max_delay_ms: 10.0,
            ap_per_channel: 2,
            optimize_polarity: false,
            max_iter: 2000,
            popsize: 20,
            tol: 1e-10,
            seed: Some(123),
            ..Default::default()
        };
        let result =
            optimize_group_delay_adaptive(&channels, &sweep_realisations, (20.0, 5000.0), &config)
                .unwrap();
        let d0 = result.per_channel[0].delay_ms;
        let d1 = result.per_channel[1].delay_ms;
        let residual = ((2.0 + d0) - (4.0 + d1)).abs();
        assert!(
            residual < 0.2,
            "Delay alignment failed: residual={:.3}ms",
            residual
        );
        assert!(
            result.improvement_db >= 6.0,
            "Improvement too low: {:.1} dB",
            result.improvement_db
        );
    }

    /// The adaptive search refuses to run with a single sweep realisation.
    #[test]
    fn test_adaptive_bootstrap_requires_min_sweeps() {
        let freq = log_freq_grid(20.0, 300.0, 50);
        let ch0 = make_delayed_channel(&freq, 0.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 5.0, 0.95);
        let channels = vec![ch0, ch1];
        let one_sweep = vec![vec![
            make_delayed_channel(&freq, 0.0, 0.95),
            make_delayed_channel(&freq, 5.0, 0.95),
        ]];
        let config = GdOptConfig::default();
        let result =
            optimize_group_delay_adaptive(&channels, &one_sweep, (20.0, 300.0), &config);
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("at least 2"));
    }

    /// LowLatency dispatches to the plain optimiser and gives some gain.
    #[test]
    fn test_mode_dispatch_low_latency() {
        let freq = log_freq_grid(20.0, 5000.0, 300);
        let ch0 = make_delayed_channel(&freq, 1.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 3.0, 0.95);
        let channels = vec![ch0, ch1];
        let config = GdOptConfig {
            ap_per_channel: 1,
            optimize_polarity: false,
            max_iter: 2000,
            popsize: 20,
            seed: Some(77),
            ..Default::default()
        };
        let result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 5000.0),
            &config,
            &ProcessingMode::LowLatency,
            None,
        )
        .unwrap();
        assert!(result.improvement_db > 0.0);
    }

    /// Hybrid mode accepts a band (up to 200 Hz) below the 300 Hz crossover.
    #[test]
    fn test_mode_dispatch_hybrid_within_crossover() {
        let freq = log_freq_grid(20.0, 200.0, 100);
        let ch0 = make_delayed_channel(&freq, 1.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 3.0, 0.95);
        let channels = vec![ch0, ch1];
        let config = GdOptConfig {
            ap_per_channel: 0,
            optimize_polarity: false,
            max_iter: 1000,
            popsize: 15,
            seed: Some(88),
            ..Default::default()
        };
        let mixed_config = MixedModeConfig {
            crossover_freq: 300.0,
            crossover_type: "LR24".to_string(),
            fir_band: "high".to_string(),
        };
        let result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 200.0),
            &config,
            &ProcessingMode::Hybrid,
            Some(&mixed_config),
        );
        assert!(result.is_ok());
    }

    /// Hybrid mode rejects a band reaching 500 Hz with a 300 Hz crossover.
    #[test]
    fn test_mode_dispatch_hybrid_exceeds_crossover() {
        let freq = log_freq_grid(20.0, 500.0, 100);
        let ch0 = make_delayed_channel(&freq, 1.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 3.0, 0.95);
        let channels = vec![ch0, ch1];
        let config = GdOptConfig::default();
        let mixed_config = MixedModeConfig {
            crossover_freq: 300.0,
            crossover_type: "LR24".to_string(),
            fir_band: "high".to_string(),
        };
        let result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 500.0),
            &config,
            &ProcessingMode::Hybrid,
            Some(&mixed_config),
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("exceeds mixed_config crossover"));
    }

    /// MixedPhase caps the all-pass budget at one section per channel even
    /// when the config asks for two.
    #[test]
    fn test_mode_dispatch_mixed_phase_caps_ap() {
        let freq = log_freq_grid(20.0, 5000.0, 300);
        let ch0 = make_delayed_channel(&freq, 1.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 3.0, 0.95);
        let channels = vec![ch0, ch1];
        let config = GdOptConfig {
            ap_per_channel: 2,
            optimize_polarity: false,
            max_iter: 2000,
            popsize: 20,
            seed: Some(99),
            ..Default::default()
        };
        let result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 5000.0),
            &config,
            &ProcessingMode::MixedPhase,
            None,
        )
        .unwrap();
        for ch in &result.per_channel {
            assert!(
                ch.ap_filters.len() <= 1,
                "MixedPhase should cap AP at 1, got {}",
                ch.ap_filters.len()
            );
        }
    }

    /// PhaseLinear is rejected with an error naming the mode.
    #[test]
    fn test_mode_dispatch_phase_linear_rejects() {
        let freq = log_freq_grid(20.0, 300.0, 50);
        let ch0 = make_delayed_channel(&freq, 0.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 5.0, 0.95);
        let channels = vec![ch0, ch1];
        let result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 300.0),
            &GdOptConfig::default(),
            &ProcessingMode::PhaseLinear,
            None,
        );
        assert!(result.is_err());
        assert!(result.unwrap_err().contains("PhaseLinear"));
    }

    /// WarpedIir and KautzModal both take the plain-optimiser path and
    /// achieve at least 6 dB on a 2 ms / 4 ms delay pair.
    #[test]
    fn test_mode_dispatch_warped_iir_same_as_low_latency() {
        let freq = log_freq_grid(20.0, 5000.0, 300);
        let ch0 = make_delayed_channel(&freq, 2.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 4.0, 0.95);
        let channels = vec![ch0, ch1];
        let config = GdOptConfig {
            ap_per_channel: 0,
            optimize_polarity: false,
            max_iter: 3000,
            popsize: 25,
            tol: 1e-10,
            seed: Some(42),
            ..Default::default()
        };
        let wi_result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 5000.0),
            &config,
            &ProcessingMode::WarpedIir,
            None,
        )
        .unwrap();
        let km_result = optimize_group_delay_for_mode(
            &channels,
            (20.0, 5000.0),
            &config,
            &ProcessingMode::KautzModal,
            None,
        )
        .unwrap();
        assert!(
            wi_result.improvement_db >= 6.0,
            "WarpedIir improvement too low: {:.1} dB",
            wi_result.improvement_db
        );
        assert!(
            km_result.improvement_db >= 6.0,
            "KautzModal improvement too low: {:.1} dB",
            km_result.improvement_db
        );
    }

    /// Like `make_delayed_channel`, but the phase also includes one all-pass
    /// section at (`ap_freq`, `ap_q`). The phase is the wrapped `arg()` of
    /// the all-pass response added to the linear delay phase.
    fn make_delayed_channel_with_allpass(
        freq_grid: &Array1<f64>,
        delay_ms: f64,
        ap_freq: f64,
        ap_q: f64,
        sample_rate: f64,
        coherence: f64,
    ) -> ChannelMeasurementInput {
        let n = freq_grid.len();
        // 0 dB everywhere (an all-pass leaves the magnitude unchanged).
        let spl = Array1::zeros(n);
        let delay_s = delay_ms * 1e-3;
        let ap = Biquad::new(BiquadFilterType::AllPass, ap_freq, sample_rate, ap_q, 0.0);
        let phase = freq_grid.mapv(|f| {
            let linear_phase = -2.0 * PI * f * delay_s;
            let ap_phase = ap.complex_response(f).arg();
            linear_phase + ap_phase
        });
        let coherence = Array1::from_elem(n, coherence);
        ChannelMeasurementInput {
            freq: freq_grid.clone(),
            spl,
            phase,
            coherence,
        }
    }

    /// Three pure delays (1, 3, 8 ms): every pair of effective delays must
    /// agree within 0.15 ms and the improvement must be at least 6 dB.
    #[test]
    fn test_qa_three_channel_lrsub_delay_recovery() {
        let freq = log_freq_grid(20.0, 5000.0, 500);
        let ch_l = make_delayed_channel(&freq, 1.0, 0.98);
        let ch_r = make_delayed_channel(&freq, 3.0, 0.98);
        let ch_sub = make_delayed_channel(&freq, 8.0, 0.98);
        let channels = vec![ch_l, ch_r, ch_sub];
        let band = (20.0, 5000.0);
        let config = GdOptConfig {
            sample_rate: 48000.0,
            max_delay_ms: 15.0,
            ap_per_channel: 0,
            optimize_polarity: false,
            max_iter: 5000,
            popsize: 30,
            tol: 1e-12,
            seed: Some(42),
            ..Default::default()
        };
        let result = optimize_group_delay(&channels, band, &config).unwrap();
        let meas_delays = [1.0_f64, 3.0, 8.0];
        let opt_delays: Vec<f64> = result.per_channel.iter().map(|ch| ch.delay_ms).collect();
        for i in 0..3 {
            for j in (i + 1)..3 {
                let eff_i = meas_delays[i] + opt_delays[i];
                let eff_j = meas_delays[j] + opt_delays[j];
                let diff = (eff_i - eff_j).abs();
                assert!(
                    diff < 0.15,
                    "Pairwise effective delay difference (ch{i} vs ch{j}) = {diff:.3} ms \
                     (expected < 0.15 ms). opt_delays = {opt_delays:?}",
                );
            }
        }
        assert!(
            result.improvement_db >= 6.0,
            "Improvement too low: {:.1} dB (expected >= 6 dB). \
             pre_rms={:.3} ms, post_rms={:.3} ms",
            result.improvement_db,
            result.sum_gd_pre_rms_ms,
            result.sum_gd_post_rms_ms,
        );
    }

    /// Same delay on both channels, but channel 1 carries a 60 Hz / Q 2
    /// all-pass. The optimiser (two all-pass sections allowed) must gain
    /// >= 6 dB. At least one channel must report all-pass filters (always
    /// true when ap_per_channel > 0).
    #[test]
    fn test_qa_two_channel_with_allpass_distortion() {
        let freq = log_freq_grid(20.0, 300.0, 400);
        let sample_rate = 48000.0;
        let ch0 = make_delayed_channel(&freq, 2.0, 0.98);
        let ch1 = make_delayed_channel_with_allpass(&freq, 2.0, 60.0, 2.0, sample_rate, 0.98);
        let channels = vec![ch0, ch1];
        let band = (20.0, 300.0);
        let config = GdOptConfig {
            sample_rate,
            max_delay_ms: 10.0,
            ap_per_channel: 2,
            ap_min_freq: 20.0,
            ap_max_freq: 300.0,
            ap_min_q: 0.3,
            ap_max_q: 10.0,
            optimize_polarity: false,
            max_iter: 5000,
            popsize: 30,
            tol: 1e-12,
            seed: Some(7),
        };
        let result = optimize_group_delay(&channels, band, &config).unwrap();
        assert!(
            result.improvement_db >= 6.0,
            "Improvement too low: {:.1} dB (expected >= 6 dB). \
             pre_rms={:.3} ms, post_rms={:.3} ms",
            result.improvement_db,
            result.sum_gd_pre_rms_ms,
            result.sum_gd_post_rms_ms,
        );
        let any_ap = result.per_channel.iter().any(|ch| !ch.ap_filters.is_empty());
        assert!(
            any_ap,
            "Expected at least one channel to have AP filters; got none. \
             ap counts: {:?}",
            result
                .per_channel
                .iter()
                .map(|ch| ch.ap_filters.len())
                .collect::<Vec<_>>(),
        );
    }

    /// With a genuine all-pass mismatch and tiny per-sweep jitter, the
    /// adaptive search is expected to accept at least one all-pass section
    /// and reach >= 4 dB.
    #[test]
    fn test_qa_adaptive_bootstrap_accepts_real_ap() {
        let freq = log_freq_grid(20.0, 300.0, 300);
        let sample_rate = 48000.0;
        let channels = vec![
            make_delayed_channel(&freq, 2.0, 0.98),
            make_delayed_channel_with_allpass(&freq, 2.0, 60.0, 2.0, sample_rate, 0.98),
        ];
        let sweep_realisations: Vec<Vec<ChannelMeasurementInput>> = (0..4)
            .map(|seed| {
                // Jitter in ms (0 to 0.06 microseconds), same on both channels.
                let jitter = seed as f64 * 0.02e-3;
                vec![
                    make_delayed_channel(&freq, 2.0 + jitter, 0.98),
                    make_delayed_channel_with_allpass(
                        &freq,
                        2.0 + jitter,
                        60.0,
                        2.0,
                        sample_rate,
                        0.98,
                    ),
                ]
            })
            .collect();
        let config = GdOptConfig {
            sample_rate,
            max_delay_ms: 10.0,
            ap_per_channel: 2,
            ap_min_freq: 20.0,
            ap_max_freq: 300.0,
            ap_min_q: 0.3,
            ap_max_q: 10.0,
            optimize_polarity: false,
            max_iter: 4000,
            popsize: 25,
            tol: 1e-10,
            seed: Some(11),
        };
        let result = optimize_group_delay_adaptive(
            &channels,
            &sweep_realisations,
            (20.0, 300.0),
            &config,
        )
        .unwrap();
        let total_ap: usize = result
            .per_channel
            .iter()
            .map(|ch| ch.ap_filters.len())
            .sum();
        assert!(
            total_ap >= 1,
            "Expected adaptive bootstrap to accept at least 1 AP filter; got 0. \
             improvement_db={:.1}",
            result.improvement_db,
        );
        assert!(
            result.improvement_db >= 4.0,
            "Improvement too low: {:.1} dB (expected >= 4 dB). \
             pre_rms={:.3} ms, post_rms={:.3} ms",
            result.improvement_db,
            result.sum_gd_pre_rms_ms,
            result.sum_gd_post_rms_ms,
        );
    }

    /// The alignment target checks:
    /// - one delay per channel;
    /// - a non-empty in-band frequency grid within the band edges;
    /// - one summed-GD value per frequency.
    #[test]
    fn test_qa_build_gd_alignment_target() {
        let freq = log_freq_grid(20.0, 5000.0, 300);
        let ch0 = make_delayed_channel(&freq, 1.0, 0.95);
        let ch1 = make_delayed_channel(&freq, 4.0, 0.95);
        let channels = vec![ch0, ch1];
        let band = (20.0, 5000.0);
        let config = GdOptConfig {
            ap_per_channel: 0,
            optimize_polarity: false,
            max_iter: 3000,
            popsize: 20,
            tol: 1e-10,
            seed: Some(55),
            ..Default::default()
        };
        let result = optimize_group_delay(&channels, band, &config).unwrap();
        let target = build_gd_alignment_target(&channels, &result, &config);
        assert_eq!(
            target.per_channel_delay_ms.len(),
            channels.len(),
            "per_channel_delay_ms length mismatch: got {}, expected {}",
            target.per_channel_delay_ms.len(),
            channels.len(),
        );
        assert!(
            !target.freq.is_empty(),
            "GdAlignmentTarget freq grid is empty"
        );
        assert!(
            target.freq[0] >= band.0 - 1e-6,
            "freq[0]={} below band_lo={}",
            target.freq[0],
            band.0,
        );
        assert!(
            *target.freq.last().unwrap() <= band.1 + 1e-6,
            "freq[last]={} above band_hi={}",
            target.freq.last().unwrap(),
            band.1,
        );
        assert_eq!(
            target.sum_gd_reference_ms.len(),
            target.freq.len(),
            "sum_gd_reference_ms and freq length mismatch: {} vs {}",
            target.sum_gd_reference_ms.len(),
            target.freq.len(),
        );
    }
}