use crate::Curve;
use crate::error::{AutoeqError, Result};
use log::{debug, info, warn};
use math_audio_dsp::analysis::{compute_average_response, find_db_point};
use math_audio_iir_fir::Biquad;
use std::collections::HashMap;
use std::path::Path;
use std::sync::{Arc, Mutex};
use super::config::validate_room_config;
use super::fir;
use super::group_delay;
use super::output;
use super::phase_alignment;
use super::types::{
ChannelDspChain, DspChainOutput, MeasurementSource, OptimizationMetadata, OptimizerConfig,
ProcessingMode, RoomConfig, SpeakerConfig, SystemModel, TargetCurveConfig,
};
pub(super) use super::speaker_eq::optimize_eq_with_optional_schroeder;
use super::speaker_eq::process_single_speaker;
use super::crossover_utils::check_group_consistency;
use super::group_processing::{
process_cardioid, process_dba, process_multisub_group, process_speaker_group,
};
/// Per-channel outcome collected by the generic (non-workflow) path of
/// `optimize_room`.
///
/// The `Ok` tuple holds, in order (see its destructuring in `optimize_room`):
/// channel name, DSP chain, pre-score, post-score, initial curve, final curve,
/// IIR biquads, mean SPL, optional arrival time (ms) and optional FIR
/// coefficients. Failed channels are stored as `Err` and the first one is
/// returned when the collected results are consumed.
type SpeakerProcessResult = std::result::Result<
    (
        String,
        ChannelDspChain,
        f64,
        f64,
        crate::Curve,
        crate::Curve,
        Vec<crate::iir::Biquad>,
        f64,
        Option<f64>,
        Option<Vec<f64>>,
    ),
    AutoeqError,
>;
/// Per-channel result tuple without the channel name.
///
/// Field order matches the `Ok` tuple destructured from
/// `process_speaker_internal` in `optimize_room`: DSP chain, pre-score,
/// post-score, initial curve, final curve, IIR biquads, mean SPL, optional
/// arrival time (ms), optional FIR coefficients.
/// NOTE(review): not referenced in this part of the file — presumably the
/// `Ok` type of `process_speaker_internal`; confirm.
type MixedModeResult = (
    ChannelDspChain,
    f64,
    f64,
    Curve,
    Curve,
    Vec<Biquad>,
    f64,
    Option<f64>,
    Option<Vec<f64>>,
);
/// Detects the usable passband of `curve` and its mean level within it.
///
/// The response is smoothed with `smooth_one_over_n_octave(.., 1)`, and that
/// result is smoothed once more. On both versions, the lower and upper
/// frequencies where the level crosses `median - 10 dB` are located. The
/// median is taken from the first smoothing pass. The widest span of the two
/// estimates is kept. When no crossing is found on a side, the first/last
/// frequency of the curve is used instead. The mean is computed on the
/// *unsmoothed* response over that span.
///
/// Returns `(None, 0.0)` when the curve has no frequencies or no SPL values,
/// or when the median smoothed level is below -100 dB. Otherwise returns
/// `(Some((f_low, f_high)), mean)`.
pub(super) fn detect_passband_and_mean(curve: &Curve) -> (Option<(f64, f64)>, f64) {
    let freqs_f32: Vec<f32> = curve.freq.iter().map(|&f| f as f32).collect();
    let spl_f32: Vec<f32> = curve.spl.iter().map(|&s| s as f32).collect();
    if spl_f32.is_empty() {
        return (None, 0.0);
    }
    // Fallback edges used when no -10 dB point is found. Taking them with
    // `first`/`last` avoids a panic on an empty frequency axis.
    let (Some(&f_first), Some(&f_last)) = (freqs_f32.first(), freqs_f32.last()) else {
        return (None, 0.0);
    };
    let smoothed_1oct = crate::read::smooth_one_over_n_octave(curve, 1);
    let spl_1oct: Vec<f32> = smoothed_1oct.spl.iter().map(|&s| s as f32).collect();
    // `total_cmp` is a total order even when NaNs are present. A `partial_cmp`
    // comparator that falls back to `Equal` is not, and the standard sorts may
    // panic on comparators that violate total order.
    let mut sorted_spl: Vec<f32> = spl_1oct.clone();
    sorted_spl.sort_unstable_by(f32::total_cmp);
    let Some(&median_spl) = sorted_spl.get(sorted_spl.len() / 2) else {
        return (None, 0.0);
    };
    if median_spl < -100.0 {
        return (None, 0.0);
    }
    let threshold = median_spl - 10.0;
    let f_low_1 = find_db_point(&freqs_f32, &spl_1oct, threshold, true).unwrap_or(f_first);
    let f_high_1 = find_db_point(&freqs_f32, &spl_1oct, threshold, false).unwrap_or(f_last);
    // A second smoothing pass gives an alternative edge estimate; the widest
    // span of both estimates is kept below.
    let spl_2oct: Vec<f32> = {
        let double_smooth = crate::read::smooth_one_over_n_octave(&smoothed_1oct, 1);
        double_smooth.spl.iter().map(|&s| s as f32).collect()
    };
    let f_low_2 = find_db_point(&freqs_f32, &spl_2oct, threshold, true).unwrap_or(f_first);
    let f_high_2 = find_db_point(&freqs_f32, &spl_2oct, threshold, false).unwrap_or(f_last);
    let f_low = f_low_1.min(f_low_2);
    let f_high = f_high_1.max(f_high_2);
    let norm_range_f32 = Some((f_low, f_high));
    let mean = compute_average_response(&freqs_f32, &spl_f32, norm_range_f32) as f64;
    (Some((f_low as f64, f_high as f64)), mean)
}
/// Generates a FIR correction for one channel after the IIR stage.
///
/// In `ProcessingMode::Hybrid`, the FIR is designed from the IIR-corrected
/// `final_curve`. Every other mode designs it from `initial_curve`.
///
/// When `output_dir` is given, the taps are also written to `<name>_fir.wav`
/// there. A failed write is only logged; the taps are still returned.
///
/// Returns `None` (after logging a warning) if FIR design fails.
fn post_generate_fir(
    name: &str,
    initial_curve: &Curve,
    final_curve: &Curve,
    config: &super::types::OptimizerConfig,
    target_curve: Option<&super::types::TargetCurveConfig>,
    sample_rate: f64,
    output_dir: Option<&Path>,
) -> Option<Vec<f64>> {
    let design_curve = if matches!(config.processing_mode, ProcessingMode::Hybrid) {
        final_curve
    } else {
        initial_curve
    };
    let taps = match fir::generate_fir_correction(design_curve, config, target_curve, sample_rate) {
        Ok(taps) => taps,
        Err(err) => {
            warn!("FIR generation failed for {}: {}", name, err);
            return None;
        }
    };
    if let Some(dir) = output_dir {
        let wav_path = dir.join(format!("{}_fir.wav", name));
        match crate::fir::save_fir_to_wav(&taps, sample_rate as u32, &wav_path) {
            Ok(_) => info!(" Saved FIR filter to {}", wav_path.display()),
            Err(err) => warn!("Failed to save FIR WAV for {}: {}", name, err),
        }
    }
    Some(taps)
}
/// Builds an excess-phase FIR for one channel in mixed-phase mode.
///
/// Returns `None` in two cases:
/// * `initial_curve` has no phase data, or its phase vector is empty.
/// * `mixed_phase::decompose_phase` fails. A warning is logged and the
///   channel keeps IIR-only correction.
///
/// Otherwise, the residual excess phase is turned into FIR taps and
/// returned. When `output_dir` is given, the taps are also saved as
/// `<name>_excess_phase_fir.wav`.
///
/// Settings come from `config.mixed_phase`, falling back to
/// `MixedPhaseConfig::default()` when that is unset.
fn post_generate_mixed_phase_fir(
    name: &str,
    initial_curve: &Curve,
    config: &super::types::OptimizerConfig,
    sample_rate: f64,
    output_dir: Option<&Path>,
) -> Option<Vec<f64>> {
    let has_phase = initial_curve
        .phase
        .as_ref()
        .is_some_and(|p| !p.is_empty());
    if !has_phase {
        return None;
    }
    let mp_config = config
        .mixed_phase
        .as_ref()
        .map(|sc| super::mixed_phase::MixedPhaseConfig {
            max_fir_length_ms: sc.max_fir_length_ms,
            pre_ringing_threshold_db: sc.pre_ringing_threshold_db,
            min_spatial_depth: sc.min_spatial_depth,
            phase_smoothing_octaves: sc.phase_smoothing_octaves,
        })
        .unwrap_or_else(super::mixed_phase::MixedPhaseConfig::default);
    let (_min_phase, _excess, delay_ms, residual) =
        match super::mixed_phase::decompose_phase(initial_curve, &mp_config) {
            Ok(parts) => parts,
            Err(e) => {
                warn!(
                    " Mixed-phase decomposition failed for '{}': {}. Using IIR only.",
                    name, e
                );
                return None;
            }
        };
    info!(
        " Mixed-phase (post-workflow) '{}': delay={:.2} ms",
        name, delay_ms
    );
    let taps = super::mixed_phase::generate_excess_phase_fir(
        &initial_curve.freq,
        &residual,
        &mp_config,
        sample_rate,
    );
    if let Some(dir) = output_dir {
        let wav_path = dir.join(format!("{}_excess_phase_fir.wav", name));
        match crate::fir::save_fir_to_wav(&taps, sample_rate as u32, &wav_path) {
            Ok(_) => info!(" Saved excess phase FIR to {}", wav_path.display()),
            Err(e) => warn!("Failed to save excess phase FIR for {}: {}", name, e),
        }
    }
    Some(taps)
}
/// Adds a phase-only (excess-phase) FIR correction to one channel.
///
/// The FIR is derived from the residual excess phase returned by
/// `mixed_phase::decompose_phase` on the channel's `initial_curve`.
///
/// Nothing is changed when:
/// * the curve carries no phase data (or an empty phase vector),
/// * phase decomposition fails (a warning is logged), or
/// * the generated FIR is empty (a warning is logged).
///
/// Otherwise:
/// * The FIR is written to `<name>_phase_correction.wav` when `output_dir`
///   is given.
/// * A convolution plugin referencing that file name is appended to `chain`.
/// * The taps are cascaded (convolved) with any FIR already stored in
///   `ch.fir_coeffs`.
///
/// NOTE(review): the plugin is appended even when no WAV was written
/// (`output_dir` is `None` or saving failed) — confirm that is intended.
fn apply_phase_correction(
    name: &str,
    ch: &mut ChannelOptimizationResult,
    chain: &mut super::types::ChannelDspChain,
    config: &super::types::MixedPhaseSerdeConfig,
    sample_rate: f64,
    output_dir: Option<&Path>,
) {
    // Phase correction needs measured phase; silently skip channels without it.
    let has_phase = ch
        .initial_curve
        .phase
        .as_ref()
        .is_some_and(|p| !p.is_empty());
    if !has_phase {
        return;
    }
    let mp_config = super::mixed_phase::MixedPhaseConfig {
        max_fir_length_ms: config.max_fir_length_ms,
        pre_ringing_threshold_db: config.pre_ringing_threshold_db,
        min_spatial_depth: config.min_spatial_depth,
        phase_smoothing_octaves: config.phase_smoothing_octaves,
    };
    let phase_fir = match super::mixed_phase::decompose_phase(&ch.initial_curve, &mp_config) {
        Ok((_min, _excess, delay_ms, residual)) => {
            info!(
                " Phase correction '{}': delay={:.2} ms, generating phase-only FIR",
                name, delay_ms
            );
            super::mixed_phase::generate_excess_phase_fir(
                &ch.initial_curve.freq,
                &residual,
                &mp_config,
                sample_rate,
            )
        }
        Err(e) => {
            warn!(" Phase correction failed for '{}': {}", name, e);
            return;
        }
    };
    // An empty FIR would wipe out any existing FIR once cascaded, and the
    // convolution plugin would reference an empty filter.
    if phase_fir.is_empty() {
        warn!(
            " Phase correction for '{}' produced an empty FIR; skipping",
            name
        );
        return;
    }
    let filename = format!("{}_phase_correction.wav", name);
    if let Some(out_dir) = output_dir {
        let wav_path = out_dir.join(&filename);
        if let Err(e) = crate::fir::save_fir_to_wav(&phase_fir, sample_rate as u32, &wav_path) {
            warn!("Failed to save phase correction FIR for {}: {}", name, e);
        } else {
            info!(" Saved phase correction FIR to {}", wav_path.display());
        }
    }
    chain
        .plugins
        .push(super::output::create_convolution_plugin(&filename));
    // Cascade with an existing FIR so `fir_coeffs` holds the combined taps;
    // an empty existing FIR is treated as "no FIR".
    ch.fir_coeffs = Some(match ch.fir_coeffs.take() {
        Some(existing) if !existing.is_empty() => convolve(&existing, &phase_fir),
        _ => phase_fir,
    });
}
/// Full linear convolution of two FIR tap sequences.
///
/// The result has `a.len() + b.len() - 1` taps: the impulse response of
/// filter `a` followed by filter `b`. If either input is empty, the result
/// is empty, rather than underflowing the length computation or producing
/// an all-zero filter.
///
/// Evaluated directly in O(`a.len()` · `b.len()`) time.
fn convolve(a: &[f64], b: &[f64]) -> Vec<f64> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let mut out = vec![0.0; a.len() + b.len() - 1];
    for (i, &av) in a.iter().enumerate() {
        // Accumulate `av * b`, shifted right by `i` taps.
        for (o, &bv) in out[i..].iter_mut().zip(b) {
            *o += av * bv;
        }
    }
    out
}
/// Spread of post-EQ mean channel levels (dB) above which a warning about a
/// possible measurement problem is logged.
const LEVEL_DIFFERENCE_WARNING_THRESHOLD: f64 = 6.0_f64;
/// Spread of per-channel arrival times (ms) above which a warning about a
/// possible measurement problem is logged.
const ARRIVAL_TIME_WARNING_THRESHOLD_MS: f64 = 50.0_f64;
/// Lists `(subwoofer_channel, main_channel)` pairs used by sub/main phase
/// alignment and group-delay optimization.
///
/// * With `config.system`:
///   - Pairs come from `system.subwoofers.mapping` (sub measurement key →
///     main role).
///   - Each sub measurement key is mapped back, via `system.speakers`, to
///     the output role that uses it. Keys that no role uses are logged and
///     skipped.
///   - A system without a subwoofer section yields no pairs.
/// * Without `config.system`:
///   - Channels named `"lfe"` or starting with `"sub"` are subwoofers.
///   - The lexicographically smallest of them is paired with every
///     non-subwoofer channel, in sorted order.
///   - The result therefore does not depend on `HashMap` iteration order,
///     and no subwoofer-like channel (e.g. `"lfe"` next to `"sub1"`) is ever
///     treated as a main speaker.
fn find_sub_main_pairings(
    config: &RoomConfig,
    curves: &HashMap<String, crate::Curve>,
) -> Vec<(String, String)> {
    // Name-based subwoofer detection for configs without a system section.
    fn is_subwoofer_channel(name: &str) -> bool {
        name == "lfe" || name.starts_with("sub")
    }

    let Some(sys) = &config.system else {
        let mut names: Vec<&str> = curves.keys().map(String::as_str).collect();
        names.sort_unstable();
        let Some(sub_name) = names.iter().copied().find(|n| is_subwoofer_channel(n)) else {
            return Vec::new();
        };
        return names
            .into_iter()
            .filter(|n| !is_subwoofer_channel(n))
            .map(|main| (sub_name.to_string(), main.to_string()))
            .collect();
    };
    let Some(subs) = &sys.subwoofers else {
        return Vec::new();
    };
    // `sys.speakers` maps role -> measurement key; invert it to find the role
    // that plays a given subwoofer measurement.
    let meas_to_role: HashMap<&String, &String> =
        sys.speakers.iter().map(|(r, m)| (m, r)).collect();
    subs.mapping
        .iter()
        .filter_map(|(sub_meas_key, main_role)| match meas_to_role.get(sub_meas_key) {
            Some(sub_role) => Some((sub_role.to_string(), main_role.clone())),
            None => {
                warn!(
                    "Subwoofer measurement '{}' not mapped to any output channel",
                    sub_meas_key
                );
                None
            }
        })
        .collect()
}
/// Decision returned by a progress callback: keep optimizing or stop.
///
/// Inside the per-iteration EQ callback built by `optimize_room`, it is
/// translated one-to-one into `crate::de::CallbackAction`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CallbackAction {
    /// Keep going.
    Continue,
    /// Request that optimization stop.
    Stop,
}
/// Progress snapshot passed to a [`RoomOptimizationCallback`].
#[derive(Debug, Clone)]
pub struct RoomOptimizationProgress {
    /// Channel currently being processed; empty for room-wide status messages.
    pub current_speaker: String,
    /// Zero-based index of the current channel (or stage item).
    pub speaker_index: usize,
    /// Number of channels, or number of sub/main pairings during alignment stages.
    pub total_speakers: usize,
    /// Optimizer iteration; 0 for stage-level notifications.
    pub iteration: usize,
    /// Iteration budget when known; 0 for stage-level notifications.
    pub max_iterations: usize,
    /// Latest loss (or the channel's post-score when a channel completes).
    pub loss: f64,
    /// Overall completion fraction; the per-iteration reports clamp it to at most 1.0.
    pub overall_progress: f64,
    /// Optional human-readable status message.
    pub message: Option<String>,
}
/// Room-level progress callback; return [`CallbackAction::Stop`] to request
/// that optimization stop.
pub type RoomOptimizationCallback =
    Box<dyn FnMut(&RoomOptimizationProgress) -> CallbackAction + Send>;
/// Progress callback for single-speaker optimization. It has the same
/// signature as [`RoomOptimizationCallback`].
pub type SpeakerOptimizationCallback =
    Box<dyn FnMut(&RoomOptimizationProgress) -> CallbackAction + Send>;
/// Optimization outcome for a single output channel.
#[derive(Debug, Clone)]
pub struct ChannelOptimizationResult {
    /// Channel (role) name.
    pub name: String,
    /// Score of the uncorrected response.
    pub pre_score: f64,
    /// Score of the corrected response.
    pub post_score: f64,
    /// Input (uncorrected) response; its phase, when present, drives phase correction.
    pub initial_curve: Curve,
    /// Response after correction.
    pub final_curve: Curve,
    /// IIR (biquad) filters of the correction.
    pub biquads: Vec<Biquad>,
    /// FIR taps when a FIR stage was generated; used for IR waveform computation.
    pub fir_coeffs: Option<Vec<f64>>,
}
/// Result of `optimize_room`: DSP chains plus per-channel details.
#[derive(Debug, Clone)]
pub struct RoomOptimizationResult {
    /// DSP chain for each channel, keyed by channel name.
    pub channels: HashMap<String, ChannelDspChain>,
    /// Per-channel optimization details, keyed by channel name.
    pub channel_results: HashMap<String, ChannelOptimizationResult>,
    /// Combined pre-correction score (the mean of channel scores on the generic path).
    pub combined_pre_score: f64,
    /// Combined post-correction score (the mean of channel scores on the generic path).
    pub combined_post_score: f64,
    /// Run metadata (scores, algorithm, iteration budget, timestamp, ...).
    pub metadata: OptimizationMetadata,
}
impl RoomOptimizationResult {
    /// Packages the per-channel DSP chains together with the optimization
    /// metadata into a [`DspChainOutput`].
    ///
    /// Both the chains and the metadata are cloned, so `self` remains usable.
    pub fn to_dsp_chain_output(&self) -> DspChainOutput {
        let channels = self.channels.clone();
        let metadata = self.metadata.clone();
        output::create_dsp_chain_output(channels, Some(metadata))
    }
}
/// Optimization outcome for a single speaker, including its DSP chain.
#[derive(Debug, Clone)]
pub struct SpeakerOptimizationResult {
    /// DSP chain built for the speaker.
    pub chain: ChannelDspChain,
    /// Score of the uncorrected response.
    pub pre_score: f64,
    /// Score of the corrected response.
    pub post_score: f64,
    /// Input (uncorrected) response.
    pub initial_curve: Curve,
    /// Response after correction.
    pub final_curve: Curve,
    /// IIR (biquad) filters of the correction.
    pub biquads: Vec<Biquad>,
    /// FIR taps when a FIR stage was generated.
    pub fir_coeffs: Option<Vec<f64>>,
}
/// Optimizes every channel of a room configuration and assembles the DSP chains.
///
/// Steps, in order:
/// 1. Clone `config` and migrate its target config. If CEA2034 correction is
///    enabled, pre-fetch CEA2034 data into `cea2034_cache`.
/// 2. Validate the configuration. Invalid configs return
///    `AutoeqError::OptimizationFailed`.
/// 3. Dispatch to a dedicated workflow (Stereo 2.0 / 2.1 / Home Cinema) when
///    all of these hold:
///    - `system` is set;
///    - no speaker groups are used;
///    - it is not a Stereo 2.0 setup with excursion / target / tilt /
///      broadband features.
///
///    The workflow result is then post-processed: FIR or mixed-phase FIR
///    generation, phase correction, IR waveforms, and inter-channel
///    deviation. `SystemModel::Custom` falls through to the generic path.
/// 4. Generic path: optimize each channel sequentially, with an optional
///    shared target level. Then apply, in order:
///    - time alignment
///    - spectral alignment
///    - Voice of God
///    - sub/main phase alignment
///    - group-delay all-pass optimization (LowLatency mode)
///    - phase correction
///    - IR waveforms
///    - inter-channel-deviation correction
///
/// `callback` receives progress updates. On the generic path, returning
/// [`CallbackAction::Stop`] stops before the next channel starts, or is
/// forwarded to the EQ optimizer. FIR WAV files go to `output_dir`; the
/// workflow path falls back to `"."` when it is `None`.
///
/// # Errors
/// Returns an error when:
/// - validation fails,
/// - a dedicated workflow fails, or
/// - any channel fails to optimize (the first failure in channel order is
///   returned).
pub fn optimize_room(
    config: &RoomConfig,
    sample_rate: f64,
    mut callback: Option<RoomOptimizationCallback>,
    output_dir: Option<&Path>,
) -> Result<RoomOptimizationResult> {
    // Work on a private copy: migration and the CEA2034 cache must not leak
    // into the caller's config.
    let mut config = config.clone();
    config.optimizer.migrate_target_config();
    if config
        .optimizer
        .cea2034_correction
        .as_ref()
        .is_some_and(|c| c.enabled)
    {
        let cache = super::cea2034_correction::pre_fetch_all_cea2034(&config);
        if !cache.is_empty() {
            info!(
                " CEA2034 cache: loaded data for {} speaker(s)",
                cache.len()
            );
            config.cea2034_cache = Some(cache);
        }
    }
    // Re-borrow immutably; everything below only reads the prepared config.
    let config = &config;
    let validation = validate_room_config(config);
    validation.print_results();
    if !validation.is_valid {
        return Err(AutoeqError::OptimizationFailed {
            message: format!(
                "Configuration validation failed with {} errors",
                validation.errors.len()
            ),
        });
    }
    // Forwards `progress` to the callback, if any; returns `true` when the
    // callback asked to stop.
    fn send_progress(
        cb: &mut Option<RoomOptimizationCallback>,
        progress: &RoomOptimizationProgress,
    ) -> bool {
        if let Some(f) = cb {
            f(progress) == CallbackAction::Stop
        } else {
            false
        }
    }
    if let Some(sys) = &config.system {
        let has_group = sys
            .speakers
            .values()
            .any(|key| matches!(config.speakers.get(key), Some(SpeakerConfig::Group(_))));
        // Stereo 2.0 with any of these features is routed to the generic
        // per-channel path instead of the dedicated workflow.
        let use_generic_for_stereo = sys.model == SystemModel::Stereo
            && sys.subwoofers.is_none()
            && (config
                .optimizer
                .excursion_protection
                .as_ref()
                .is_some_and(|e| e.enabled)
                || config.optimizer.target_response.is_some()
                || config.optimizer.target_tilt.is_some()
                || config
                    .optimizer
                    .broadband_target_matching
                    .as_ref()
                    .is_some_and(|b| b.enabled));
        if use_generic_for_stereo {
            info!("Stereo 2.0 with excursion/tilt/broadband features, using generic path");
        }
        // Dedicated workflows are only used for configs without speaker groups.
        if !has_group && !use_generic_for_stereo {
            let workflow_name = match sys.model {
                SystemModel::Stereo => {
                    if sys.subwoofers.is_some() {
                        "Stereo 2.1"
                    } else {
                        "Stereo 2.0"
                    }
                }
                SystemModel::HomeCinema => "Home Cinema",
                SystemModel::Custom => "Custom",
            };
            if sys.model != SystemModel::Custom {
                // NOTE(review): the Stop result is ignored here; the dedicated
                // workflows cannot be cancelled through the callback.
                send_progress(
                    &mut callback,
                    &RoomOptimizationProgress {
                        current_speaker: String::new(),
                        speaker_index: 0,
                        total_speakers: sys.speakers.len(),
                        iteration: 0,
                        max_iterations: 0,
                        loss: 0.0,
                        overall_progress: 0.0,
                        message: Some(format!(
                            "Starting {} workflow ({} channels)",
                            workflow_name,
                            sys.speakers.len()
                        )),
                    },
                );
            }
            let workflow_result = match sys.model {
                SystemModel::Stereo => {
                    if sys.subwoofers.is_some() {
                        Some(super::workflows::optimize_stereo_2_1(
                            config,
                            sys,
                            sample_rate,
                            output_dir.unwrap_or(Path::new(".")),
                        ))
                    } else {
                        Some(super::workflows::optimize_stereo_2_0(
                            config,
                            sys,
                            sample_rate,
                            output_dir.unwrap_or(Path::new(".")),
                        ))
                    }
                }
                SystemModel::HomeCinema => Some(super::workflows::optimize_home_cinema(
                    config,
                    sys,
                    sample_rate,
                    output_dir.unwrap_or(Path::new(".")),
                )),
                // No dedicated workflow: fall through to the generic path below.
                SystemModel::Custom => None,
            };
            if let Some(result) = workflow_result {
                let mut result = result?;
                let summary: Vec<String> = result
                    .channel_results
                    .iter()
                    .map(|(name, ch)| {
                        format!(" {}: {:.4} -> {:.4}", name, ch.pre_score, ch.post_score)
                    })
                    .collect();
                send_progress(
                    &mut callback,
                    &RoomOptimizationProgress {
                        current_speaker: String::new(),
                        speaker_index: result.channel_results.len(),
                        total_speakers: result.channel_results.len(),
                        iteration: 0,
                        max_iterations: 0,
                        loss: result.combined_post_score,
                        overall_progress: 1.0,
                        message: Some(format!(
                            "{} workflow complete:\n{}",
                            workflow_name,
                            summary.join("\n")
                        )),
                    },
                );
                // Fill in FIR taps for channels the workflow left without one.
                if matches!(
                    config.optimizer.processing_mode,
                    ProcessingMode::PhaseLinear | ProcessingMode::Hybrid
                ) {
                    let out_dir = output_dir.unwrap_or(Path::new("."));
                    for (name, ch) in result.channel_results.iter_mut() {
                        if ch.fir_coeffs.is_some() {
                            continue;
                        }
                        ch.fir_coeffs = post_generate_fir(
                            name,
                            &ch.initial_curve,
                            &ch.final_curve,
                            &config.optimizer,
                            config.target_curve.as_ref(),
                            sample_rate,
                            Some(out_dir),
                        );
                    }
                }
                if config.optimizer.processing_mode == ProcessingMode::MixedPhase {
                    let out_dir = output_dir.unwrap_or(Path::new("."));
                    for (name, ch) in result.channel_results.iter_mut() {
                        if ch.fir_coeffs.is_some() {
                            continue;
                        }
                        ch.fir_coeffs = post_generate_mixed_phase_fir(
                            name,
                            &ch.initial_curve,
                            &config.optimizer,
                            sample_rate,
                            Some(out_dir),
                        );
                        // Reference the WAV file name used by
                        // post_generate_mixed_phase_fir.
                        if ch.fir_coeffs.is_some()
                            && let Some(chain) = result.channels.get_mut(name)
                        {
                            let filename = format!("{}_excess_phase_fir.wav", name);
                            chain
                                .plugins
                                .push(super::output::create_convolution_plugin(&filename));
                        }
                    }
                }
                if let Some(ref pc_config) = config.optimizer.phase_correction {
                    let out_dir = output_dir.unwrap_or(Path::new("."));
                    // Snapshot the names so both maps can be borrowed mutably
                    // inside the loop.
                    let names: Vec<String> = result.channel_results.keys().cloned().collect();
                    for name in &names {
                        if let Some(ch) = result.channel_results.get_mut(name)
                            && let Some(chain) = result.channels.get_mut(name)
                        {
                            apply_phase_correction(
                                name,
                                ch,
                                chain,
                                pc_config,
                                sample_rate,
                                Some(out_dir),
                            );
                        }
                    }
                }
                // Compute pre/post impulse-response waveforms per channel.
                // Only the first plugin of type "delay" contributes the delay.
                for (channel_name, ch_result) in &result.channel_results {
                    let delay_ms = result
                        .channels
                        .get(channel_name)
                        .and_then(|chain| chain.plugins.iter().find(|p| p.plugin_type == "delay"))
                        .and_then(|p| p.parameters.get("delay_ms").and_then(|v| v.as_f64()))
                        .unwrap_or(0.0);
                    if let Some((pre_ir, post_ir)) =
                        super::ir_waveform::compute_channel_ir_waveforms(
                            &ch_result.initial_curve,
                            &ch_result.biquads,
                            ch_result.fir_coeffs.as_deref(),
                            delay_ms,
                            sample_rate,
                        )
                        && let Some(chain) = result.channels.get_mut(channel_name)
                    {
                        chain.pre_ir = Some(pre_ir);
                        chain.post_ir = Some(post_ir);
                    }
                }
                if result.channel_results.len() > 1 {
                    compute_and_correct_icd(&mut result, config, sample_rate);
                }
                return Ok(result);
            }
        }
    }
    // Generic path: resolve (channel name, speaker config) pairs, either via
    // the system role mapping or directly from `config.speakers`.
    let channels_to_process: Vec<(String, SpeakerConfig)> = if let Some(sys) = &config.system {
        info!("Using SystemConfig for channel mapping");
        sys.speakers
            .iter()
            .filter_map(|(role, key)| match config.speakers.get(key) {
                Some(cfg) => Some((role.clone(), cfg.clone())),
                None => {
                    warn!(
                        "System config references missing speaker key '{}' for role '{}'",
                        key, role
                    );
                    None
                }
            })
            .collect()
    } else {
        config
            .speakers
            .iter()
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect()
    };
    let total_speakers = channels_to_process.len();
    info!("Processing {} channels", total_speakers);
    // Pre-pass: average the in-band mean SPL of all Single-source channels.
    // The result is passed to process_speaker_internal as a shared target level.
    let shared_mean_spl: Option<f64> = if total_speakers > 1 {
        let min_freq = config.optimizer.min_freq;
        let max_freq = config.optimizer.max_freq;
        let mut channel_means: Vec<f64> = Vec::new();
        let mut excluded_group_count = 0_usize;
        for (_name, speaker_config) in &channels_to_process {
            // Single sources that fail to load are silently left out here.
            if let SpeakerConfig::Single(source) = speaker_config
                && let Ok(curve) = crate::read::load_source(source)
            {
                let freqs_f32: Vec<f32> = curve.freq.iter().map(|&f| f as f32).collect();
                let spl_f32: Vec<f32> = curve.spl.iter().map(|&s| s as f32).collect();
                let mean = compute_average_response(
                    &freqs_f32,
                    &spl_f32,
                    Some((min_freq as f32, max_freq as f32)),
                ) as f64;
                channel_means.push(mean);
            } else if !matches!(speaker_config, SpeakerConfig::Single(_)) {
                excluded_group_count += 1;
            }
        }
        if excluded_group_count > 0 {
            info!(
                "Shared mean pre-pass: {} non-Single speaker(s) excluded (Group/MultiSub/DBA/Cardioid)",
                excluded_group_count
            );
        }
        // A shared level is only used when at least two channels contribute.
        if channel_means.len() > 1 {
            let avg = channel_means.iter().sum::<f64>() / channel_means.len() as f64;
            info!(
                "Shared target level: {:.1} dB (average of {} channels)",
                avg,
                channel_means.len()
            );
            Some(avg)
        } else {
            None
        }
    } else {
        None
    };
    send_progress(
        &mut callback,
        &RoomOptimizationProgress {
            current_speaker: String::new(),
            speaker_index: 0,
            total_speakers,
            iteration: 0,
            max_iterations: 0,
            loss: 0.0,
            overall_progress: 0.0,
            message: Some(format!(
                "Starting optimization for {} channels",
                total_speakers
            )),
        },
    );
    // DE budget estimate: 4 parameters per filter for the "free" and
    // "ls-pk-hs" PEQ models, 3 otherwise. The generation count is derived
    // from max_iter, with a floor of 5000.
    // NOTE(review): in this function `max_iterations` is only logged and used
    // for progress reporting; it is not passed to process_speaker_internal.
    let params_per_filter = match config.optimizer.peq_model.as_str() {
        "free" | "ls-pk-hs" => 4,
        _ => 3,
    };
    let n_params = config.optimizer.num_filters * params_per_filter;
    let n_free = n_params.max(1);
    let desired_pop = config
        .optimizer
        .population
        .max(1)
        .min(config.optimizer.max_iter.max(1));
    let pop_multiplier = desired_pop.div_ceil(n_free).max(4);
    let population_size = pop_multiplier * n_free;
    let max_iterations =
        (config.optimizer.max_iter.saturating_sub(population_size) / population_size).max(5000);
    info!(
        "DE budget: {} params, population_size={}, max_generations={} (from max_iter={}, floor=5000)",
        n_params, population_size, max_iterations, config.optimizer.max_iter
    );
    // The callback is shared with the per-iteration EQ callbacks created
    // below, hence Arc<Mutex<..>>.
    let callback_shared: Arc<Mutex<Option<RoomOptimizationCallback>>> =
        Arc::new(Mutex::new(callback));
    let mut results: Vec<SpeakerProcessResult> = Vec::with_capacity(total_speakers);
    for (speaker_idx, (channel_name, speaker_config)) in channels_to_process.into_iter().enumerate()
    {
        info!("Processing channel: {}", channel_name);
        {
            let mut guard = callback_shared.lock().unwrap();
            let stop = send_progress(
                &mut guard,
                &RoomOptimizationProgress {
                    current_speaker: channel_name.clone(),
                    speaker_index: speaker_idx,
                    total_speakers,
                    iteration: 0,
                    max_iterations: 0,
                    loss: 0.0,
                    overall_progress: speaker_idx as f64 / total_speakers as f64,
                    message: Some(format!("Processing channel: {}", channel_name)),
                },
            );
            // A Stop request here ends processing before this channel starts.
            if stop {
                break;
            }
        }
        // Adapts per-iteration optimizer progress into room-level progress and
        // maps the room callback's decision onto the DE callback action.
        let eq_callback: Option<crate::optim::OptimProgressCallback> = {
            let cb = Arc::clone(&callback_shared);
            let name = channel_name.clone();
            let si = speaker_idx;
            let ts = total_speakers;
            let mi = max_iterations;
            Some(Box::new(move |iter: usize, loss: f64| {
                let base_progress = si as f64 / ts as f64;
                let speaker_progress = if mi > 0 { iter as f64 / mi as f64 } else { 0.0 };
                let overall = (base_progress + speaker_progress / ts as f64).min(1.0);
                // A poisoned lock or a missing callback means "continue".
                if let Ok(mut guard) = cb.lock()
                    && let Some(room_cb) = guard.as_mut()
                {
                    let action = room_cb(&RoomOptimizationProgress {
                        current_speaker: name.clone(),
                        speaker_index: si,
                        total_speakers: ts,
                        iteration: iter,
                        max_iterations: mi,
                        loss,
                        overall_progress: overall,
                        message: None,
                    });
                    return match action {
                        CallbackAction::Continue => crate::de::CallbackAction::Continue,
                        CallbackAction::Stop => crate::de::CallbackAction::Stop,
                    };
                }
                crate::de::CallbackAction::Continue
            }))
        };
        let result = process_speaker_internal(
            &channel_name,
            &speaker_config,
            config,
            sample_rate,
            output_dir,
            eq_callback,
            shared_mean_spl,
        );
        match result {
            Ok((
                chain,
                pre_score,
                post_score,
                initial_curve,
                final_curve,
                biquads,
                mean_spl,
                arrival_time_ms,
                fir_coeffs,
            )) => {
                {
                    let mut guard = callback_shared.lock().unwrap();
                    let stop = send_progress(
                        &mut guard,
                        &RoomOptimizationProgress {
                            current_speaker: channel_name.clone(),
                            speaker_index: speaker_idx,
                            total_speakers,
                            iteration: 0,
                            max_iterations: 0,
                            loss: post_score,
                            overall_progress: (speaker_idx + 1) as f64 / total_speakers as f64,
                            message: Some(format!(
                                "Channel {}: {:.4} -> {:.4}",
                                channel_name, pre_score, post_score
                            )),
                        },
                    );
                    // NOTE(review): a Stop returned here is discarded; the loop
                    // only stops if the next channel's start notification also
                    // returns Stop.
                    let _ = stop;
                }
                results.push(Ok((
                    channel_name,
                    chain,
                    pre_score,
                    post_score,
                    initial_curve,
                    final_curve,
                    biquads,
                    mean_spl,
                    arrival_time_ms,
                    fir_coeffs,
                )));
            }
            // Keep the error; the first one is returned when results are consumed below.
            Err(e) => {
                results.push(Err(e));
            }
        }
    }
    let mut channel_chains: HashMap<String, ChannelDspChain> = HashMap::new();
    let mut channel_results: HashMap<String, ChannelOptimizationResult> = HashMap::new();
    let mut pre_scores: Vec<f64> = Vec::new();
    let mut post_scores: Vec<f64> = Vec::new();
    let mut curves: HashMap<String, crate::Curve> = HashMap::new();
    let mut channel_means: HashMap<String, f64> = HashMap::new();
    let mut channel_arrivals: HashMap<String, f64> = HashMap::new();
    for res in results {
        // The first failed channel aborts the whole optimization.
        let (
            channel_name,
            chain,
            pre_score,
            post_score,
            initial_curve,
            final_curve,
            biquads,
            mean_spl,
            arrival_time_ms,
            fir_coeffs,
        ) = res?;
        channel_chains.insert(channel_name.clone(), chain);
        curves.insert(channel_name.clone(), final_curve.clone());
        pre_scores.push(pre_score);
        post_scores.push(post_score);
        channel_means.insert(channel_name.clone(), mean_spl);
        if let Some(arrival_ms) = arrival_time_ms {
            channel_arrivals.insert(channel_name.clone(), arrival_ms);
        }
        // Generate a FIR for channels without one, except in LowLatency and
        // MixedPhase modes.
        // NOTE(review): unlike the workflow path, MixedPhase mode does not call
        // post_generate_mixed_phase_fir here — confirm this is intended.
        let fir_coeffs = if fir_coeffs.is_none()
            && !matches!(
                config.optimizer.processing_mode,
                ProcessingMode::LowLatency | ProcessingMode::MixedPhase
            ) {
            post_generate_fir(
                &channel_name,
                &initial_curve,
                &final_curve,
                &config.optimizer,
                config.target_curve.as_ref(),
                sample_rate,
                output_dir,
            )
        } else {
            fir_coeffs
        };
        channel_results.insert(
            channel_name.clone(),
            ChannelOptimizationResult {
                name: channel_name,
                pre_score,
                post_score,
                initial_curve,
                final_curve,
                biquads,
                fir_coeffs,
            },
        );
    }
    // No channel reported an arrival time: with several channels, estimate
    // arrivals from the measured phase (200.0 / 2000.0 are presumably the Hz
    // band used for the estimate — confirm).
    let phase_ir_sync = channel_arrivals.is_empty() && channel_results.len() > 1;
    if phase_ir_sync {
        for (channel_name, result) in &channel_results {
            if let Some(arrival_ms) =
                super::time_align::estimate_arrival_from_phase(&result.initial_curve, 200.0, 2000.0)
            {
                channel_arrivals.insert(channel_name.clone(), arrival_ms);
            }
        }
        if channel_arrivals.len() > 1 {
            info!(
                "Auto IR sync: phase-estimated arrival times for {} channels",
                channel_arrivals.len()
            );
            for (name, arrival) in &channel_arrivals {
                info!(
                    " Channel '{}': phase-estimated arrival = {:.2} ms",
                    name, arrival
                );
            }
        } else {
            // Alignment needs at least two arrival times.
            channel_arrivals.clear();
        }
    }
    // Time alignment. Phase-estimated arrivals enable it even when
    // allow_delay() is false.
    if (config.optimizer.allow_delay() || phase_ir_sync) && channel_arrivals.len() > 1 {
        let arrivals: Vec<f64> = channel_arrivals.values().copied().collect();
        let min_arrival = arrivals.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_arrival = arrivals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let arrival_spread = max_arrival - min_arrival;
        if arrival_spread > ARRIVAL_TIME_WARNING_THRESHOLD_MS {
            warn!(
                "Channel arrival times differ by {:.1} ms (threshold: {:.1} ms). \
                 This may indicate measurement issues or very different speaker distances.",
                arrival_spread, ARRIVAL_TIME_WARNING_THRESHOLD_MS
            );
            for (name, arrival) in &channel_arrivals {
                info!(" Channel '{}': arrival time = {:.2} ms", name, arrival);
            }
        }
        let alignment_delays = super::time_align::calculate_alignment_delays(&channel_arrivals);
        for (channel_name, delay_ms) in &alignment_delays {
            // Delays of 0.01 ms or less are ignored; the delay plugin goes first in the chain.
            if *delay_ms > 0.01
                && let Some(chain) = channel_chains.get_mut(channel_name)
            {
                chain
                    .plugins
                    .insert(0, output::create_delay_plugin(*delay_ms));
                info!(
                    " Channel '{}': added {:.3} ms delay for time alignment",
                    channel_name, delay_ms
                );
            }
        }
    } else if channel_arrivals.is_empty() && config.speakers.len() > 1 {
        info!("No arrival time data (WAV or phase) available for time alignment. Skipping.");
    }
    // Post-EQ level check and spectral alignment across channels.
    if curves.len() > 1 {
        let min_freq = config.optimizer.min_freq;
        let max_freq = config.optimizer.max_freq;
        // NOTE(review): this shadows the `sample_rate` argument for this block
        // only, using the recording playback rate (default 48000); the other
        // stages use the argument — confirm this is intended.
        let sample_rate = config
            .recording_config
            .as_ref()
            .and_then(|rc| rc.playback_sample_rate)
            .unwrap_or(48000) as f64;
        let mut post_eq_means: HashMap<String, f64> = HashMap::new();
        for (channel_name, final_curve) in &curves {
            let freqs_f32: Vec<f32> = final_curve.freq.iter().map(|&f| f as f32).collect();
            let spl_f32: Vec<f32> = final_curve.spl.iter().map(|&s| s as f32).collect();
            let post_mean = compute_average_response(
                &freqs_f32,
                &spl_f32,
                Some((min_freq as f32, max_freq as f32)),
            ) as f64;
            post_eq_means.insert(channel_name.clone(), post_mean);
        }
        let means: Vec<f64> = post_eq_means.values().copied().collect();
        let min_mean = means.iter().cloned().fold(f64::INFINITY, f64::min);
        let max_mean = means.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
        let level_spread = max_mean - min_mean;
        info!(
            "Post-EQ spectral alignment: level spread = {:.2} dB across {} channels",
            level_spread,
            post_eq_means.len()
        );
        for (name, mean) in &post_eq_means {
            info!(" Channel '{}': post-EQ mean SPL = {:.1} dB", name, mean);
        }
        if level_spread > LEVEL_DIFFERENCE_WARNING_THRESHOLD {
            warn!(
                "Channel levels differ by {:.1} dB (threshold: {:.1} dB). \
                 This may indicate measurement issues (mic placement, cable problems, etc.).",
                level_spread, LEVEL_DIFFERENCE_WARNING_THRESHOLD
            );
        }
        let alignment_results = super::spectral_align::compute_spectral_alignment(
            &curves,
            sample_rate,
            min_freq,
            max_freq,
        );
        super::spectral_align::log_spectral_alignment(&alignment_results);
        for (channel_name, result) in &alignment_results {
            if let Some(chain) = channel_chains.get_mut(channel_name) {
                let (eq_plugin, gain_plugin) =
                    super::spectral_align::create_alignment_plugins(result, sample_rate);
                if let Some(eq) = eq_plugin {
                    chain.plugins.push(eq);
                }
                if let Some(gain) = gain_plugin {
                    chain.plugins.push(gain);
                }
            }
        }
    }
    // Voice of God: align channels against a reference channel using the
    // corrected (final) curves. Failure is logged and otherwise ignored.
    if let Some(vog_config) = &config.optimizer.vog
        && vog_config.enabled
    {
        info!(
            "Running Voice of God alignment (reference: '{}')...",
            vog_config.reference_channel
        );
        let corrected_curves: HashMap<String, Curve> = channel_results
            .iter()
            .map(|(name, result)| (name.clone(), result.final_curve.clone()))
            .collect();
        match super::voice_of_god::compute_voice_of_god(
            &corrected_curves,
            &vog_config.reference_channel,
            sample_rate,
            config.optimizer.min_freq,
            config.optimizer.max_freq,
        ) {
            Ok(vog_results) => {
                for (channel_name, vog_result) in &vog_results {
                    let plugins = super::voice_of_god::create_vog_plugins(vog_result, sample_rate);
                    if !plugins.is_empty()
                        && let Some(chain) = channel_chains.get_mut(channel_name)
                    {
                        for plugin in plugins {
                            chain.plugins.push(plugin);
                        }
                    }
                }
            }
            Err(e) => {
                warn!("Voice of God optimization failed: {}", e);
            }
        }
    }
    // Sub/main phase alignment. Results are keyed by main channel and applied
    // after the loop. Mains without a curve are skipped silently.
    let mut phase_alignment_results: HashMap<String, (f64, bool)> = HashMap::new();
    if config.optimizer.allow_delay()
        && let Some(phase_config) = &config.optimizer.phase_alignment
        && phase_config.enabled
    {
        let pairings = find_sub_main_pairings(config, &curves);
        if pairings.is_empty() {
            warn!("Phase alignment enabled but no valid sub-main pairings found.");
        } else {
            info!("Running phase alignment optimization...");
            send_progress(
                &mut callback_shared.lock().unwrap(),
                &RoomOptimizationProgress {
                    current_speaker: String::new(),
                    speaker_index: 0,
                    total_speakers: pairings.len(),
                    iteration: 0,
                    max_iterations: 0,
                    loss: 0.0,
                    overall_progress: 0.0,
                    message: Some("Running phase alignment...".to_string()),
                },
            );
            for (sub_name, main_name) in &pairings {
                let sub_curve = match curves.get(sub_name) {
                    Some(c) => c,
                    None => {
                        warn!(
                            "Subwoofer channel '{}' not found for phase alignment",
                            sub_name
                        );
                        continue;
                    }
                };
                if let Some(speaker_curve) = curves.get(main_name) {
                    if sub_curve.phase.is_some() && speaker_curve.phase.is_some() {
                        match phase_alignment::optimize_phase_alignment(
                            sub_curve,
                            speaker_curve,
                            phase_config,
                        ) {
                            Ok(result) => {
                                info!(
                                    " Phase alignment '{}' with '{}': delay={:.2}ms, invert={}, improvement={:.2}dB",
                                    main_name,
                                    sub_name,
                                    result.delay_ms,
                                    result.invert_polarity,
                                    result.improvement_db
                                );
                                phase_alignment_results.insert(
                                    main_name.clone(),
                                    (result.delay_ms, result.invert_polarity),
                                );
                            }
                            Err(e) => {
                                warn!(" Phase alignment failed for '{}': {}", main_name, e);
                            }
                        }
                    } else {
                        debug!(
                            " Skipping phase alignment for '{}': no phase data available",
                            main_name
                        );
                    }
                }
            }
        }
    }
    // Apply phase-alignment results: the polarity inversion becomes the first
    // plugin, and delays of 0.01 ms or less are ignored.
    for (speaker_name, (delay_ms, invert)) in &phase_alignment_results {
        if let Some(chain) = channel_chains.get_mut(speaker_name) {
            if *invert {
                let invert_plugin = output::create_gain_plugin_with_invert(0.0, true);
                chain.plugins.insert(0, invert_plugin);
                info!(" Applied polarity inversion to '{}'", speaker_name);
            }
            if *delay_ms > 0.01 {
                output::add_delay_plugin(chain, *delay_ms);
                info!(
                    " Applied {:.3} ms phase alignment delay to '{}'",
                    delay_ms, speaker_name
                );
            }
        }
    }
    // Group-delay optimization with all-pass IIR filters; only in LowLatency mode.
    if let Some(gd_opt) = &config.optimizer.gd_opt
        && gd_opt.enabled
        && config.optimizer.processing_mode == ProcessingMode::LowLatency
    {
        info!("Running Group Delay Optimization (IIR Mode)...");
        let pairings = find_sub_main_pairings(config, &curves);
        if pairings.is_empty() {
            warn!("GD-Opt enabled but no valid sub-main pairings found.");
        } else {
            send_progress(
                &mut callback_shared.lock().unwrap(),
                &RoomOptimizationProgress {
                    current_speaker: String::new(),
                    speaker_index: 0,
                    total_speakers: pairings.len(),
                    iteration: 0,
                    max_iterations: 0,
                    loss: 0.0,
                    overall_progress: 0.0,
                    message: Some("Running group delay optimization...".to_string()),
                },
            );
        }
        // Fixed 200 Hz upper bound for the group-delay optimization range.
        let min_freq = config.optimizer.min_freq;
        let max_freq = 200.0;
        for (sub_name, main_name) in pairings {
            if let (Some(sub_curve), Some(main_curve)) =
                (curves.get(&sub_name), curves.get(&main_name))
            {
                info!(" Optimizing GD for '{}' vs '{}'", main_name, sub_name);
                send_progress(
                    &mut callback_shared.lock().unwrap(),
                    &RoomOptimizationProgress {
                        current_speaker: format!("GD {}", main_name),
                        speaker_index: 0,
                        total_speakers: 1,
                        iteration: 0,
                        max_iterations: 0,
                        loss: 0.0,
                        overall_progress: 0.0,
                        message: Some(format!("Optimizing GD for '{}'", main_name)),
                    },
                );
                match group_delay::optimize_gd_iir(
                    sub_curve,
                    main_curve,
                    min_freq,
                    max_freq,
                    sample_rate,
                ) {
                    Ok(filters) => {
                        if !filters.is_empty() {
                            info!(
                                " Generated {} All-Pass filters for GD alignment",
                                filters.len()
                            );
                            if let Some(chain) = channel_chains.get_mut(&main_name) {
                                let plugin = output::create_eq_plugin(&filters);
                                chain.plugins.push(plugin);
                            }
                        } else {
                            info!(" No AP filters needed");
                        }
                    }
                    Err(e) => {
                        warn!(" GD optimization failed for '{}': {}", main_name, e);
                    }
                }
            } else {
                warn!(
                    "GD-Opt: Channel '{}' or '{}' not found in results",
                    sub_name, main_name
                );
            }
        }
    }
    // Phase-only FIR correction per channel. Names are snapshotted so both
    // maps can be borrowed mutably.
    if let Some(ref pc_config) = config.optimizer.phase_correction {
        let names: Vec<String> = channel_results.keys().cloned().collect();
        for name in &names {
            if let Some(ch) = channel_results.get_mut(name.as_str())
                && let Some(chain) = channel_chains.get_mut(name.as_str())
            {
                apply_phase_correction(name, ch, chain, pc_config, sample_rate, output_dir);
            }
        }
    }
    // Pre/post impulse-response waveforms per channel.
    // NOTE(review): only the first plugin of type "delay" contributes the
    // delay. If output::add_delay_plugin (phase alignment) creates a separate
    // delay plugin, that delay is not included — confirm.
    for (channel_name, result) in &channel_results {
        let delay_ms = channel_chains
            .get(channel_name)
            .and_then(|chain| chain.plugins.iter().find(|p| p.plugin_type == "delay"))
            .and_then(|p| p.parameters.get("delay_ms").and_then(|v| v.as_f64()))
            .unwrap_or(0.0);
        if let Some((pre_ir, post_ir)) = super::ir_waveform::compute_channel_ir_waveforms(
            &result.initial_curve,
            &result.biquads,
            result.fir_coeffs.as_deref(),
            delay_ms,
            sample_rate,
        ) && let Some(chain) = channel_chains.get_mut(channel_name)
        {
            chain.pre_ir = Some(pre_ir);
            chain.post_ir = Some(post_ir);
        }
    }
    let avg_pre_score = if !pre_scores.is_empty() {
        pre_scores.iter().sum::<f64>() / pre_scores.len() as f64
    } else {
        0.0
    };
    let avg_post_score = if !post_scores.is_empty() {
        post_scores.iter().sum::<f64>() / post_scores.len() as f64
    } else {
        0.0
    };
    info!(
        "Average pre-score: {:.4}, post-score: {:.4}",
        avg_pre_score, avg_post_score
    );
    // Consistency check across channels that form an acoustic group.
    let acoustic_groups = identify_acoustic_groups(config);
    for (group_name, group_channels) in &acoustic_groups {
        if group_channels.len() > 1 {
            debug!("Acoustic Group '{}': {:?}", group_name, group_channels);
            check_group_consistency(group_name, group_channels, &channel_means, &curves);
        }
    }
    let metadata = OptimizationMetadata {
        pre_score: avg_pre_score,
        post_score: avg_post_score,
        algorithm: config.optimizer.algorithm.clone(),
        iterations: config.optimizer.max_iter,
        timestamp: chrono::Utc::now().to_rfc3339(),
        inter_channel_deviation: None,
    };
    let mut result = RoomOptimizationResult {
        channels: channel_chains,
        channel_results,
        combined_pre_score: avg_pre_score,
        combined_post_score: avg_post_score,
        metadata,
    };
    if curves.len() > 1 {
        compute_and_correct_icd(&mut result, config, sample_rate);
    }
    Ok(result)
}
fn compute_and_correct_icd(
result: &mut RoomOptimizationResult,
config: &RoomConfig,
sample_rate: f64,
) {
let final_curves: HashMap<String, crate::Curve> = result
.channel_results
.iter()
.map(|(name, ch)| (name.clone(), ch.final_curve.clone()))
.collect();
let f3 = final_curves
.values()
.filter_map(|c| super::excursion::detect_f3(c, None).ok().map(|r| r.f3_hz))
.reduce(f64::min)
.unwrap_or(50.0);
let icd = super::spectral_align::compute_inter_channel_deviation(&final_curves, f3);
info!(
"Inter-channel deviation: midrange_rms={:.2}dB, peak={:.1}dB @{:.0}Hz, passband_rms={:.2}dB",
icd.midrange_rms_db, icd.midrange_peak_db, icd.midrange_peak_freq, icd.passband_rms_db,
);
let matching_cfg = config.optimizer.channel_matching.as_ref();
let enabled = matching_cfg.is_some_and(|c| c.enabled);
let threshold = matching_cfg.map_or(1.5, |c| c.threshold_db);
let max_filters = matching_cfg.map_or(3, |c| c.max_filters);
if enabled && icd.midrange_rms_db > threshold {
info!(
"ICD midrange_rms={:.2}dB > threshold={:.1}dB — applying channel matching correction (max {} filters/ch)",
icd.midrange_rms_db, threshold, max_filters,
);
let corrections = super::spectral_align::correct_inter_channel_deviation(
&final_curves,
f3,
max_filters,
sample_rate,
);
for correction in &corrections {
if let Some(plugin) = &correction.plugin {
info!(
" Channel '{}': {} matching filters",
correction.channel_name,
correction.filters.len(),
);
for f in &correction.filters {
info!(
" PK @ {:.0} Hz, Q={:.2}, gain={:+.1} dB",
f.freq, f.q, f.db_gain,
);
}
if let Some(chain) = result.channels.get_mut(&correction.channel_name) {
chain.plugins.push(plugin.clone());
}
if let Some(ch_result) = result.channel_results.get_mut(&correction.channel_name) {
let resp = crate::response::compute_peq_complex_response(
&correction.filters,
&ch_result.final_curve.freq,
sample_rate,
);
ch_result.final_curve =
crate::response::apply_complex_response(&ch_result.final_curve, &resp);
if let Some(chain) = result.channels.get_mut(&correction.channel_name)
&& let Some(ref display_final) = chain.final_curve
{
let display_curve: crate::Curve = display_final.clone().into();
let display_resp = crate::response::compute_peq_complex_response(
&correction.filters,
&display_curve.freq,
sample_rate,
);
let corrected =
crate::response::apply_complex_response(&display_curve, &display_resp);
chain.final_curve = Some((&corrected).into());
}
}
}
}
let corrected_curves: HashMap<String, crate::Curve> = result
.channel_results
.iter()
.map(|(name, ch)| (name.clone(), ch.final_curve.clone()))
.collect();
let icd_after =
super::spectral_align::compute_inter_channel_deviation(&corrected_curves, f3);
info!(
"ICD after correction: midrange_rms={:.2}dB (was {:.2}dB), peak={:.1}dB @{:.0}Hz",
icd_after.midrange_rms_db,
icd.midrange_rms_db,
icd_after.midrange_peak_db,
icd_after.midrange_peak_freq,
);
result.metadata.inter_channel_deviation = Some(icd_after);
} else {
if enabled {
info!(
"ICD midrange_rms={:.2}dB <= threshold={:.1}dB — no correction needed",
icd.midrange_rms_db, threshold,
);
}
result.metadata.inter_channel_deviation = Some(icd);
}
}
/// Groups channels that are expected to sound alike, so their responses can
/// be checked for consistency.
///
/// Groups come from two sources:
/// 1. Channels whose config reports a `speaker_name()` are grouped under that
///    name. Channel order follows iteration over `config.speakers`.
/// 2. Channels without a speaker name are paired from a fixed list of
///    positional labels (`L`/`R`, `SL`/`SR`, `SBL`/`SBR`, `TFL`/`TFR`,
///    `TRL`/`TRR`, `FWL`/`FWR`). A pair becomes the group `"<left>-<right>"`
///    only when both labels exist in `config.speakers` without a speaker name.
///
/// A positional pair replaces any speaker-name group that happens to have the
/// same `"<left>-<right>"` key.
fn identify_acoustic_groups(config: &RoomConfig) -> HashMap<String, Vec<String>> {
    const SYMMETRIC_PAIRS: [(&str, &str); 6] = [
        ("L", "R"),
        ("SL", "SR"),
        ("SBL", "SBR"),
        ("TFL", "TFR"),
        ("TRL", "TRR"),
        ("FWL", "FWR"),
    ];

    let mut acoustic_groups: HashMap<String, Vec<String>> = HashMap::new();
    for (channel, speaker) in &config.speakers {
        if let Some(model) = speaker.speaker_name() {
            acoustic_groups
                .entry(model.to_string())
                .or_default()
                .push(channel.clone());
        }
    }

    // A label is "positioned" when the channel exists and carries no speaker name.
    let is_unnamed_channel = |label: &str| {
        config
            .speakers
            .get(label)
            .is_some_and(|speaker| speaker.speaker_name().is_none())
    };

    for (left, right) in SYMMETRIC_PAIRS {
        if is_unnamed_channel(left) && is_unnamed_channel(right) {
            acoustic_groups.insert(
                format!("{}-{}", left, right),
                vec![left.to_string(), right.to_string()],
            );
        }
    }
    acoustic_groups
}
/// Optimizes one speaker channel on its own, outside a full room run.
///
/// Builds a minimal [`RoomConfig`] that carries only `optimizer_config` and
/// `target_curve`: no system model, no other speakers, no crossovers, no
/// recording config and no CEA2034 cache. It then runs the shared per-channel
/// pipeline with:
/// - no output directory (the pipeline falls back to `"."`);
/// - no progress callback;
/// - no shared mean SPL.
///
/// The mean SPL and arrival time produced by that pipeline are discarded.
///
/// NOTE(review): `_callback` is accepted but not forwarded anywhere.
///
/// # Errors
/// Propagates any error returned by the underlying speaker processing.
pub fn optimize_speaker(
    channel_name: &str,
    speaker_config: &SpeakerConfig,
    optimizer_config: &OptimizerConfig,
    target_curve: Option<&TargetCurveConfig>,
    sample_rate: f64,
    _callback: Option<SpeakerOptimizationCallback>,
) -> Result<SpeakerOptimizationResult> {
    let standalone_config = RoomConfig {
        version: super::types::default_config_version(),
        system: None,
        speakers: HashMap::new(),
        crossovers: None,
        target_curve: target_curve.cloned(),
        optimizer: optimizer_config.clone(),
        recording_config: None,
        cea2034_cache: None,
    };

    let outcome = process_speaker_internal(
        channel_name,
        speaker_config,
        &standalone_config,
        sample_rate,
        None, // output_dir
        None, // progress callback
        None, // shared_mean_spl
    )?;
    let (chain, pre_score, post_score, initial_curve, final_curve, biquads, _, _, fir_coeffs) =
        outcome;

    Ok(SpeakerOptimizationResult {
        chain,
        pre_score,
        post_score,
        initial_curve,
        final_curve,
        biquads,
        fir_coeffs,
    })
}
/// Dispatches one channel's speaker configuration to the matching processor.
///
/// `output_dir` defaults to the current directory (`"."`) when `None`.
/// Only the `SpeakerConfig::Single` path receives `callback` and
/// `shared_mean_spl`. The group, multi-sub, DBA and cardioid processors are
/// called without them.
fn process_speaker_internal(
    channel_name: &str,
    speaker_config: &SpeakerConfig,
    room_config: &RoomConfig,
    sample_rate: f64,
    output_dir: Option<&Path>,
    callback: Option<crate::optim::OptimProgressCallback>,
    shared_mean_spl: Option<f64>,
) -> Result<MixedModeResult> {
    let out_dir: &Path = match output_dir {
        Some(dir) => dir,
        None => Path::new("."),
    };
    match speaker_config {
        SpeakerConfig::Single(measurement) => process_single_speaker(
            channel_name,
            measurement,
            room_config,
            sample_rate,
            out_dir,
            callback,
            // NOTE(review): the seventh argument is always `None` on this path;
            // its meaning is defined by `process_single_speaker`.
            None,
            shared_mean_spl,
        ),
        SpeakerConfig::Group(grp) => {
            process_speaker_group(channel_name, grp, room_config, sample_rate, out_dir)
        }
        SpeakerConfig::MultiSub(subs) => {
            process_multisub_group(channel_name, subs, room_config, sample_rate, out_dir)
        }
        SpeakerConfig::Dba(dba_cfg) => {
            process_dba(channel_name, dba_cfg, room_config, sample_rate, out_dir)
        }
        SpeakerConfig::Cardioid(cardioid_cfg) => {
            process_cardioid(channel_name, cardioid_cfg, room_config, sample_rate, out_dir)
        }
    }
}
/// Returns the WAV file path attached to a measurement source, if any.
///
/// - `Single`: its measurement reference is inspected.
/// - `Multiple`: only the first reference is inspected. An empty list gives `None`.
/// - In-memory sources always give `None`.
///
/// A path is returned only when the inspected reference is an inline
/// measurement whose `wav_path` is set.
pub(super) fn extract_wav_path(source: &MeasurementSource) -> Option<String> {
    let measurement_ref = match source {
        MeasurementSource::Single(single) => &single.measurement,
        MeasurementSource::Multiple(multi) => multi.measurements.first()?,
        MeasurementSource::InMemory(_) | MeasurementSource::InMemoryMultiple(_) => return None,
    };
    if let crate::MeasurementRef::Inline(inline) = measurement_ref {
        inline.wav_path.clone()
    } else {
        None
    }
}