use crate::Curve;
use log::info;
use math_audio_iir_fir::{Biquad, BiquadFilterType, DEFAULT_Q_HIGH_LOW_SHELF};
use math_audio_optimisation::{LMConfigBuilder, levenberg_marquardt};
use ndarray::Array1;
use std::collections::HashMap;
use super::output;
use super::types::PluginConfigWrapper;
/// Corner frequency (Hz) of the low-shelf correction filter.
pub const LOWSHELF_FREQ: f64 = 200.0;
/// Corner frequency (Hz) of the high-shelf correction filter.
pub const HIGHSHELF_FREQ: f64 = 4000.0;
/// Maximum magnitude (dB) allowed for a shelf correction.
const MAX_SHELF_GAIN_DB: f64 = 6.0;
/// Maximum magnitude (dB) allowed for a flat (broadband) gain correction.
const MAX_FLAT_GAIN_DB: f64 = 12.0;
/// Corrections smaller than this (dB) are zeroed out / not applied.
const MIN_CORRECTION_DB: f64 = 0.3;
/// Per-channel correction gains produced by the spectral-alignment fit.
#[derive(Debug, Clone)]
pub struct SpectralAlignmentResult {
    /// Low-shelf correction gain (dB), applied at `LOWSHELF_FREQ`.
    pub lowshelf_gain_db: f64,
    /// High-shelf correction gain (dB), applied at `HIGHSHELF_FREQ`.
    pub highshelf_gain_db: f64,
    /// Broadband (flat) gain correction in dB.
    pub flat_gain_db: f64,
    /// Weighted RMS (dB) of the deviation remaining after the fit.
    pub residual_rms_db: f64,
}
/// Derive per-channel shelf/gain corrections that align every channel's
/// in-band response to the mean ("reference") response of all channels.
///
/// Returns an empty map when fewer than two channels are present or when
/// fewer than three frequency points fall inside `[min_freq, max_freq]`.
/// Flat gains are re-centred to sum to zero across channels, and any
/// correction smaller than `MIN_CORRECTION_DB` is suppressed to zero.
///
/// NOTE(review): assumes every curve shares the frequency grid of the
/// first curve in the map — confirm at the call site.
pub fn compute_spectral_alignment(
    curves: &HashMap<String, Curve>,
    sample_rate: f64,
    min_freq: f64,
    max_freq: f64,
) -> HashMap<String, SpectralAlignmentResult> {
    if curves.len() <= 1 {
        return HashMap::new();
    }
    let grid = &curves.values().next().unwrap().freq;
    // Boolean mask selecting the analysis band on the shared grid.
    let in_band: Vec<bool> = grid
        .iter()
        .map(|&f| (min_freq..=max_freq).contains(&f))
        .collect();
    let active_count = in_band.iter().filter(|&&keep| keep).count();
    if active_count < 3 {
        return HashMap::new();
    }
    let band_freq: Array1<f64> = Array1::from(
        grid.iter()
            .zip(&in_band)
            .filter_map(|(&f, &keep)| keep.then_some(f))
            .collect::<Vec<_>>(),
    );
    let reference_spl = compute_reference_curve(curves, &in_band, active_count);
    let weights = compute_octave_weights(&band_freq);

    let mut results: HashMap<String, SpectralAlignmentResult> = HashMap::new();
    for (name, curve) in curves {
        let band_spl: Array1<f64> = Array1::from(
            curve
                .spl
                .iter()
                .zip(&in_band)
                .filter_map(|(&s, &keep)| keep.then_some(s))
                .collect::<Vec<_>>(),
        );
        let deviation = &band_spl - &reference_spl;
        let (ls_fit, hs_fit, flat_fit, residual_rms_db) =
            fit_shelf_gain_iterative(&deviation, &band_freq, sample_rate, &weights);
        // The fit describes the channel's excess; the correction is its
        // negation, clamped to the allowed range.
        results.insert(
            name.clone(),
            SpectralAlignmentResult {
                lowshelf_gain_db: (-ls_fit).clamp(-MAX_SHELF_GAIN_DB, MAX_SHELF_GAIN_DB),
                highshelf_gain_db: (-hs_fit).clamp(-MAX_SHELF_GAIN_DB, MAX_SHELF_GAIN_DB),
                flat_gain_db: (-flat_fit).clamp(-MAX_FLAT_GAIN_DB, MAX_FLAT_GAIN_DB),
                residual_rms_db,
            },
        );
    }

    // Remove the common level component so flat gains sum to zero, then
    // zero out anything too small to be worth applying.
    let mean_flat =
        results.values().map(|r| r.flat_gain_db).sum::<f64>() / results.len() as f64;
    for entry in results.values_mut() {
        entry.flat_gain_db -= mean_flat;
        if entry.lowshelf_gain_db.abs() < MIN_CORRECTION_DB {
            entry.lowshelf_gain_db = 0.0;
        }
        if entry.highshelf_gain_db.abs() < MIN_CORRECTION_DB {
            entry.highshelf_gain_db = 0.0;
        }
        if entry.flat_gain_db.abs() < MIN_CORRECTION_DB {
            entry.flat_gain_db = 0.0;
        }
    }
    results
}
/// Measure how far apart the channels' frequency responses are, expressed
/// as the per-frequency min→max spread in dB.
///
/// Each channel is first level-normalised by subtracting its mean SPL over
/// the passband `[f3_hz, 10 kHz]`, so pure level offsets do not count as
/// deviation. RMS/peak statistics are reported for the midrange
/// (200 Hz–4 kHz) and the full passband.
///
/// Returns zeroed statistics for fewer than two channels.
/// NOTE(review): assumes all curves share the first curve's frequency
/// grid. Bins beyond a shorter channel's data are skipped per channel; if
/// every channel is shorter than the grid at some bin, that bin's spread
/// degenerates to -inf — verify grids are uniform upstream.
pub fn compute_inter_channel_deviation(
    final_curves: &HashMap<String, crate::Curve>,
    f3_hz: f64,
) -> super::types::InterChannelDeviation {
    use super::types::InterChannelDeviation;
    // Zeroed result used for the degenerate (<2 channels) cases.
    let empty = InterChannelDeviation {
        deviation_per_freq: Vec::new(),
        midrange_rms_db: 0.0,
        passband_rms_db: 0.0,
        midrange_peak_db: 0.0,
        midrange_peak_freq: 0.0,
    };
    if final_curves.len() <= 1 {
        return empty;
    }
    let first_curve = match final_curves.values().next() {
        Some(c) => c,
        None => return empty,
    };
    let freq = &first_curve.freq;
    let n = freq.len();
    // Level-normalise each channel: subtract its mean passband SPL.
    let normalized: Vec<(&String, Vec<f64>)> = final_curves
        .iter()
        .map(|(name, curve)| {
            let mut sum = 0.0;
            let mut count = 0usize;
            for i in 0..curve.spl.len().min(n) {
                let f = freq[i];
                if f >= f3_hz && f <= 10000.0 {
                    sum += curve.spl[i];
                    count += 1;
                }
            }
            let mean = if count > 0 { sum / count as f64 } else { 0.0 };
            let norm_spl: Vec<f64> = curve.spl.iter().map(|&s| s - mean).collect();
            (name, norm_spl)
        })
        .collect();
    // Per-frequency spread plus running RMS / peak accumulators.
    let mut deviation_per_freq = Vec::with_capacity(n);
    let mut midrange_sum_sq = 0.0;
    let mut midrange_count = 0usize;
    let mut midrange_peak_db: f64 = 0.0;
    let mut midrange_peak_freq: f64 = 0.0;
    let mut passband_sum_sq = 0.0;
    let mut passband_count = 0usize;
    for i in 0..n {
        let f = freq[i];
        let mut min_spl = f64::INFINITY;
        let mut max_spl = f64::NEG_INFINITY;
        for (_name, spl) in &normalized {
            if i < spl.len() {
                min_spl = min_spl.min(spl[i]);
                max_spl = max_spl.max(spl[i]);
            }
        }
        // Spread = worst-case channel-to-channel difference at this bin.
        let spread = max_spl - min_spl;
        deviation_per_freq.push((f, spread));
        if (200.0..=4000.0).contains(&f) {
            midrange_sum_sq += spread * spread;
            midrange_count += 1;
            if spread > midrange_peak_db {
                midrange_peak_db = spread;
                midrange_peak_freq = f;
            }
        }
        if f >= f3_hz && f <= 10000.0 {
            passband_sum_sq += spread * spread;
            passband_count += 1;
        }
    }
    let midrange_rms = if midrange_count > 0 {
        (midrange_sum_sq / midrange_count as f64).sqrt()
    } else {
        0.0
    };
    let passband_rms = if passband_count > 0 {
        (passband_sum_sq / passband_count as f64).sqrt()
    } else {
        0.0
    };
    InterChannelDeviation {
        deviation_per_freq,
        midrange_rms_db: midrange_rms,
        passband_rms_db: passband_rms,
        midrange_peak_db,
        midrange_peak_freq,
    }
}
/// Per-channel outcome of the channel-matching (inter-channel deviation
/// correction) stage.
#[derive(Debug, Clone)]
pub struct ChannelMatchingResult {
    /// Channel this correction applies to.
    pub channel_name: String,
    /// Peaking filters that pull this channel toward the mean response.
    pub filters: Vec<Biquad>,
    /// EQ plugin wrapping `filters`; `None` when no filters were needed.
    pub plugin: Option<super::types::PluginConfigWrapper>,
}
/// Design up to `max_filters` peaking filters per channel that pull each
/// channel toward the mean (level-normalised) response of all channels.
///
/// Channels are level-normalised over the passband `[f3_hz, 10 kHz]`
/// before the mean reference is built, so level offsets are not "fixed"
/// with EQ. Each channel's deviation is smoothed (1/3-octave), local
/// |maxima| above 1 dB are collected, the largest `max_filters` candidates
/// are kept, peaks closer than 1/3 octave to an already-selected one are
/// dropped, and each survivor gets an opposing peaking filter whose Q is
/// estimated from the deviation's bandwidth.
///
/// Returns one [`ChannelMatchingResult`] per channel (possibly with no
/// filters); empty vec when fewer than two channels or `max_filters == 0`.
/// NOTE(review): assumes all curves share the first curve's frequency
/// grid — confirm upstream.
pub fn correct_inter_channel_deviation(
    final_curves: &HashMap<String, crate::Curve>,
    f3_hz: f64,
    max_filters: usize,
    sample_rate: f64,
) -> Vec<ChannelMatchingResult> {
    if final_curves.len() <= 1 || max_filters == 0 {
        return Vec::new();
    }
    let first_curve = match final_curves.values().next() {
        Some(c) => c,
        None => return Vec::new(),
    };
    let freq = &first_curve.freq;
    let n = freq.len();
    // Mean SPL per channel over the passband, used for level normalisation.
    let passband_means: HashMap<String, f64> = final_curves
        .iter()
        .map(|(name, curve)| {
            let mut sum = 0.0;
            let mut count = 0usize;
            for i in 0..curve.spl.len().min(n) {
                if freq[i] >= f3_hz && freq[i] <= 10000.0 {
                    sum += curve.spl[i];
                    count += 1;
                }
            }
            let mean = if count > 0 { sum / count as f64 } else { 0.0 };
            (name.clone(), mean)
        })
        .collect();
    // Mean of the level-normalised channels = matching target.
    let mut reference = vec![0.0; n];
    for (name, curve) in final_curves {
        let mean = passband_means[name];
        for (i, ref_val) in reference
            .iter_mut()
            .enumerate()
            .take(n.min(curve.spl.len()))
        {
            *ref_val += (curve.spl[i] - mean) / final_curves.len() as f64;
        }
    }
    let mut results = Vec::new();
    for (name, curve) in final_curves {
        let mean = passband_means[name];
        // Deviation of this (normalised) channel from the reference.
        let diff: Vec<f64> = (0..n.min(curve.spl.len()))
            .map(|i| (curve.spl[i] - mean) - reference[i])
            .collect();
        let smoothed_diff = smooth_for_peak_finding(&diff, freq, n);
        // Local |maxima| of the smoothed deviation inside the passband,
        // ignoring anything below 1 dB.
        let mut peaks: Vec<(usize, f64)> = Vec::new();
        for i in 1..smoothed_diff.len().saturating_sub(1) {
            let f = freq[i];
            if f < f3_hz || f > 10000.0 {
                continue;
            }
            let abs_val = smoothed_diff[i].abs();
            if abs_val < 1.0 {
                continue;
            }
            let is_peak = smoothed_diff[i].abs() >= smoothed_diff[i - 1].abs()
                && smoothed_diff[i].abs() >= smoothed_diff[i + 1].abs();
            if is_peak {
                peaks.push((i, smoothed_diff[i]));
            }
        }
        // Largest deviations first, capped at max_filters candidates.
        peaks.sort_by(|a, b| b.1.abs().partial_cmp(&a.1.abs()).unwrap());
        peaks.truncate(max_filters);
        // Greedily keep peaks at least 1/3 octave away from already
        // selected (stronger) ones.
        let mut selected: Vec<(usize, f64)> = Vec::new();
        for &(idx, dev) in &peaks {
            let f = freq[idx];
            let too_close = selected.iter().any(|&(sidx, _)| {
                let sf = freq[sidx];
                (f / sf).abs().log2().abs() < 1.0 / 3.0
            });
            if !too_close {
                selected.push((idx, dev));
            }
        }
        // One opposing peaking filter per selected deviation.
        let mut filters = Vec::new();
        for &(idx, dev) in &selected {
            let f = freq[idx];
            let gain_db = -dev;
            let q = estimate_correction_q(&smoothed_diff, freq, idx);
            filters.push(Biquad::new(
                math_audio_iir_fir::BiquadFilterType::Peak,
                f,
                sample_rate,
                q,
                gain_db,
            ));
        }
        let plugin = if filters.is_empty() {
            None
        } else {
            Some(output::create_labeled_eq_plugin(
                &filters,
                "channel_matching",
            ))
        };
        results.push(ChannelMatchingResult {
            channel_name: name.clone(),
            filters,
            plugin,
        });
    }
    results
}
/// Apply 1/3-octave box smoothing to `diff` for peak detection.
///
/// Each output bin is the arithmetic mean of all `diff` samples whose
/// frequency lies within ±1/6 octave of that bin's frequency. Bins whose
/// window captures no samples fall back to the raw value (or 0.0 when
/// `diff` is shorter than `n`).
fn smooth_for_peak_finding(diff: &[f64], freq: &Array1<f64>, n: usize) -> Vec<f64> {
    const OCTAVE_WIDTH: f64 = 1.0 / 3.0;
    let half_span = 2.0_f64.powf(OCTAVE_WIDTH / 2.0);
    let usable = n.min(diff.len());
    (0..n)
        .map(|i| {
            let centre = freq[i];
            let lo = centre / half_span;
            let hi = centre * half_span;
            // Collect every sample inside the window around this bin.
            let window: Vec<f64> = (0..usable)
                .filter(|&j| freq[j] >= lo && freq[j] <= hi)
                .map(|j| diff[j])
                .collect();
            if window.is_empty() {
                diff.get(i).copied().unwrap_or(0.0)
            } else {
                window.iter().sum::<f64>() / window.len() as f64
            }
        })
        .collect()
}
/// Estimate a peaking-filter Q from the half-amplitude bandwidth of the
/// deviation around `peak_idx`.
///
/// Walks outward from the peak until |diff| drops below half the peak
/// magnitude on each side; Q = f_peak / (f_hi - f_lo), clamped to
/// [0.5, 8.0]. Falls back to 2.0 when no positive bandwidth is found.
fn estimate_correction_q(diff: &[f64], freq: &Array1<f64>, peak_idx: usize) -> f64 {
    let threshold = diff[peak_idx].abs() * 0.5;
    let peak_freq = freq[peak_idx];

    // Lower half-amplitude edge (defaults to the peak itself).
    let lo_freq = (0..peak_idx)
        .rev()
        .find(|&i| diff[i].abs() < threshold)
        .map_or(peak_freq, |i| freq[i]);

    // Upper half-amplitude edge (defaults to the peak itself).
    let upper_limit = diff.len().min(freq.len());
    let hi_freq = ((peak_idx + 1)..upper_limit)
        .find(|&i| diff[i].abs() < threshold)
        .map_or(peak_freq, |i| freq[i]);

    let bandwidth = hi_freq - lo_freq;
    if bandwidth > 0.0 {
        (peak_freq / bandwidth).clamp(0.5, 8.0)
    } else {
        2.0
    }
}
/// Convert a [`SpectralAlignmentResult`] into optional EQ (shelves) and
/// gain plugins.
///
/// The EQ plugin carries only shelves whose gain reaches
/// `MIN_CORRECTION_DB`; the gain plugin exists only when the flat gain
/// does. Either element of the returned tuple may be `None`.
pub fn create_alignment_plugins(
    result: &SpectralAlignmentResult,
    sample_rate: f64,
) -> (Option<PluginConfigWrapper>, Option<PluginConfigWrapper>) {
    // Table of candidate shelves; only significant ones become filters.
    let mut shelf_filters = Vec::new();
    for (kind, corner, gain) in [
        (BiquadFilterType::Lowshelf, LOWSHELF_FREQ, result.lowshelf_gain_db),
        (BiquadFilterType::Highshelf, HIGHSHELF_FREQ, result.highshelf_gain_db),
    ] {
        if gain.abs() >= MIN_CORRECTION_DB {
            shelf_filters.push(Biquad::new(
                kind,
                corner,
                sample_rate,
                DEFAULT_Q_HIGH_LOW_SHELF,
                gain,
            ));
        }
    }
    let eq_plugin = (!shelf_filters.is_empty()).then(|| output::create_eq_plugin(&shelf_filters));
    let gain_plugin = (result.flat_gain_db.abs() >= MIN_CORRECTION_DB)
        .then(|| output::create_gain_plugin(result.flat_gain_db));
    (eq_plugin, gain_plugin)
}
/// Fit shelf + flat-gain corrections that move `curve` toward `target`
/// over `[min_freq, max_freq]`.
///
/// Returns `None` when fewer than three frequency points fall in band,
/// or when every fitted correction is below `MIN_CORRECTION_DB` (nothing
/// worth applying). Gains are clamped to ±`MAX_SHELF_GAIN_DB` /
/// ±`MAX_FLAT_GAIN_DB`.
///
/// NOTE(review): assumes `target` is sampled on the same frequency grid
/// as `curve` — the mask built from `curve.freq` is applied to both.
pub fn compute_target_alignment(
    curve: &Curve,
    target: &Curve,
    min_freq: f64,
    max_freq: f64,
    sample_rate: f64,
) -> Option<SpectralAlignmentResult> {
    let freq = &curve.freq;
    // Boolean mask selecting the analysis band.
    let mask: Vec<bool> = freq
        .iter()
        .map(|&f| f >= min_freq && f <= max_freq)
        .collect();
    let n_active: usize = mask.iter().filter(|m| **m).count();
    if n_active < 3 {
        return None;
    }
    let active_freq: Array1<f64> = Array1::from(
        freq.iter()
            .zip(mask.iter())
            .filter(|(_, m)| **m)
            .map(|(f, _)| *f)
            .collect::<Vec<_>>(),
    );
    let channel_spl: Array1<f64> = Array1::from(
        curve
            .spl
            .iter()
            .zip(mask.iter())
            .filter(|(_, m)| **m)
            .map(|(s, _)| *s)
            .collect::<Vec<_>>(),
    );
    let target_spl: Array1<f64> = Array1::from(
        target
            .spl
            .iter()
            .zip(mask.iter())
            .filter(|(_, m)| **m)
            .map(|(s, _)| *s)
            .collect::<Vec<_>>(),
    );
    // Deviation of the measurement above/below the target.
    let diff = &channel_spl - &target_spl;
    let weights = compute_octave_weights(&active_freq);
    let (ls_fit, hs_fit, flat_fit, residual_rms) =
        fit_shelf_gain_iterative(&diff, &active_freq, sample_rate, &weights);
    // Corrections oppose the fitted deviation, clamped to safe ranges.
    let ls_gain = (-ls_fit).clamp(-MAX_SHELF_GAIN_DB, MAX_SHELF_GAIN_DB);
    let hs_gain = (-hs_fit).clamp(-MAX_SHELF_GAIN_DB, MAX_SHELF_GAIN_DB);
    let flat_gain = (-flat_fit).clamp(-MAX_FLAT_GAIN_DB, MAX_FLAT_GAIN_DB);
    if ls_gain.abs() < MIN_CORRECTION_DB
        && hs_gain.abs() < MIN_CORRECTION_DB
        && flat_gain.abs() < MIN_CORRECTION_DB
    {
        return None;
    }
    Some(SpectralAlignmentResult {
        lowshelf_gain_db: ls_gain,
        highshelf_gain_db: hs_gain,
        flat_gain_db: flat_gain,
        residual_rms_db: residual_rms,
    })
}
/// Average the masked SPL values of every channel, producing the shared
/// reference that per-channel deviations are measured against.
///
/// `n_active` must equal the number of `true` entries in `mask`, and each
/// curve's `spl` must align with `mask` — presumably guaranteed by the
/// caller; verify if grids can differ.
fn compute_reference_curve(
    curves: &HashMap<String, Curve>,
    mask: &[bool],
    n_active: usize,
) -> Array1<f64> {
    let mut accumulator = Array1::zeros(n_active);
    for curve in curves.values() {
        let masked: Vec<f64> = curve
            .spl
            .iter()
            .zip(mask)
            .filter_map(|(&s, &keep)| keep.then_some(s))
            .collect();
        accumulator += &Array1::from(masked);
    }
    accumulator / curves.len() as f64
}
fn build_basis_vectors(freq: &Array1<f64>, sample_rate: f64) -> (Array1<f64>, Array1<f64>) {
let ls = Biquad::new(
BiquadFilterType::Lowshelf,
LOWSHELF_FREQ,
sample_rate,
DEFAULT_Q_HIGH_LOW_SHELF,
1.0, );
let hs = Biquad::new(
BiquadFilterType::Highshelf,
HIGHSHELF_FREQ,
sample_rate,
DEFAULT_Q_HIGH_LOW_SHELF,
1.0, );
(ls.np_log_result(freq), hs.np_log_result(freq))
}
/// Weight each frequency bin by the octave span it covers, so that a
/// log-spaced analysis does not over-weight densely sampled regions.
///
/// Interior bins get half the log2-distance between their neighbours;
/// the two edge bins get the distance to their single neighbour. Weights
/// are normalised to sum to `n` (mean weight 1.0).
///
/// Degenerate inputs (fewer than two bins) return unit weights: the
/// previous `for i in 1..n - 1` computed `n - 1` on a `usize`, which
/// underflows (and panics) for an empty input, and `n == 1` produced an
/// all-zero weight vector that silently degenerates downstream fits.
fn compute_octave_weights(freq: &Array1<f64>) -> Array1<f64> {
    let n = freq.len();
    // Guard: no bin spacing exists with fewer than two points, and
    // `n - 1` would underflow for n == 0.
    if n < 2 {
        return Array1::ones(n);
    }
    let log2_freq: Vec<f64> = freq.iter().map(|&f| f.log2()).collect();
    let mut weights = Array1::zeros(n);
    // Interior bins: half the span between the two neighbours.
    for i in 1..n - 1 {
        weights[i] = (log2_freq[i + 1] - log2_freq[i - 1]) / 2.0;
    }
    // Edge bins: span to the single neighbour.
    weights[0] = log2_freq[1] - log2_freq[0];
    weights[n - 1] = log2_freq[n - 1] - log2_freq[n - 2];
    // Normalise so the weights sum to n.
    let total: f64 = weights.sum();
    if total > 0.0 {
        weights *= n as f64 / total;
    }
    weights
}
/// Evaluate the combined log-magnitude response (dB) of a flat gain plus
/// the low and high shelves at the given gains, sampled at `freq`.
///
/// Shelves with negligible gain (|gain| <= 1e-12 dB) are skipped, so the
/// zero-gain case costs nothing.
fn evaluate_shelf_response(
    freq: &Array1<f64>,
    sample_rate: f64,
    ls_gain: f64,
    hs_gain: f64,
    flat_gain: f64,
) -> Array1<f64> {
    // Start from the flat gain everywhere, then add each active shelf.
    let mut response = Array1::from_elem(freq.len(), flat_gain);
    for (kind, corner, gain) in [
        (BiquadFilterType::Lowshelf, LOWSHELF_FREQ, ls_gain),
        (BiquadFilterType::Highshelf, HIGHSHELF_FREQ, hs_gain),
    ] {
        if gain.abs() > 1e-12 {
            let shelf = Biquad::new(kind, corner, sample_rate, DEFAULT_Q_HIGH_LOW_SHELF, gain);
            response += &shelf.np_log_result(freq);
        }
    }
    response
}
/// Fit low-shelf, high-shelf, and flat gains (dB) to `diff` by weighted
/// least squares, then refine with bounded Levenberg–Marquardt.
///
/// The linear 3x3 WLS solution seeds the optimiser; the LM refinement
/// matters because a shelf's log response is only approximately linear
/// in its gain, so the seed alone is biased for large gains. Residuals
/// are multiplied by `sqrt(weights)` so the LM cost matches the weighted
/// RMS reported alongside the gains.
///
/// Returns `(ls_gain, hs_gain, flat_gain, weighted_residual_rms)`. If
/// the optimiser errors, falls back to the clamped linear seed.
fn fit_shelf_gain_iterative(
    diff: &Array1<f64>,
    freq: &Array1<f64>,
    sample_rate: f64,
    weights: &Array1<f64>,
) -> (f64, f64, f64, f64) {
    let n = freq.len();
    let flat_basis = Array1::ones(n);
    let (ls_basis, hs_basis) = build_basis_vectors(freq, sample_rate);
    // Linear seed; may be non-finite when the WLS system is degenerate.
    let (ls_init, hs_init, flat_init, _) =
        solve_3x3_wls(diff, &ls_basis, &hs_basis, &flat_basis, weights);
    let x0 = if ls_init.is_finite() && hs_init.is_finite() && flat_init.is_finite() {
        ndarray::array![ls_init, hs_init, flat_init]
    } else {
        ndarray::array![0.0, 0.0, 0.0]
    };
    // Owned copies so the residual closure is self-contained.
    let diff = diff.clone();
    let freq = freq.clone();
    let weights = weights.clone();
    let residual_fn = |x: &Array1<f64>| -> Array1<f64> {
        let response = evaluate_shelf_response(&freq, sample_rate, x[0], x[1], x[2]);
        let r = &diff - &response;
        // sqrt(w) on the residual => w on the squared error.
        &r * &weights.mapv(f64::sqrt)
    };
    // Shelf search range is twice the final clamp applied by callers, so
    // the fit itself is not artificially saturated.
    let bounds = [
        (-MAX_SHELF_GAIN_DB * 2.0, MAX_SHELF_GAIN_DB * 2.0),
        (-MAX_SHELF_GAIN_DB * 2.0, MAX_SHELF_GAIN_DB * 2.0),
        (-MAX_FLAT_GAIN_DB, MAX_FLAT_GAIN_DB),
    ];
    let config = LMConfigBuilder::new()
        .x0(x0)
        .maxiter(10)
        .tol(1e-10)
        .jacobian_epsilon(0.1)
        .build();
    let report = match levenberg_marquardt(&residual_fn, &bounds, config) {
        Ok(r) => r,
        Err(_) => {
            // LM failed: fall back to the (clamped, finite) linear seed.
            let ls = if ls_init.is_finite() {
                ls_init.clamp(-MAX_SHELF_GAIN_DB * 2.0, MAX_SHELF_GAIN_DB * 2.0)
            } else {
                0.0
            };
            let hs = if hs_init.is_finite() {
                hs_init.clamp(-MAX_SHELF_GAIN_DB * 2.0, MAX_SHELF_GAIN_DB * 2.0)
            } else {
                0.0
            };
            let flat = if flat_init.is_finite() {
                flat_init.clamp(-MAX_FLAT_GAIN_DB, MAX_FLAT_GAIN_DB)
            } else {
                0.0
            };
            let actual = evaluate_shelf_response(&freq, sample_rate, ls, hs, flat);
            let residual = &diff - &actual;
            let rms = (residual
                .iter()
                .zip(weights.iter())
                .map(|(&r, &w)| w * r * r)
                .sum::<f64>()
                / n as f64)
                .sqrt();
            return (ls, hs, flat, rms);
        }
    };
    // Weighted RMS of what remains after applying the fitted response.
    let actual = evaluate_shelf_response(&freq, sample_rate, report.x[0], report.x[1], report.x[2]);
    let residual = &diff - &actual;
    let weighted_sq: f64 = residual
        .iter()
        .zip(weights.iter())
        .map(|(&r, &w)| w * r * r)
        .sum();
    let residual_rms = (weighted_sq / n as f64).sqrt();
    (report.x[0], report.x[1], report.x[2], residual_rms)
}
/// Solve the 3-parameter weighted least-squares problem
/// `diff ~ x0*ls_basis + x1*hs_basis + x2*flat_basis` in closed form.
///
/// Builds the symmetric 3x3 normal-equation matrix A = B^T W B and
/// inverts it via the explicit adjugate. Returns
/// `(x0, x1, x2, weighted_residual_rms)`. When the system is
/// (near-)singular only the flat gain is solved and both shelves are
/// reported as zero.
fn solve_3x3_wls(
    diff: &Array1<f64>,
    ls_basis: &Array1<f64>,
    hs_basis: &Array1<f64>,
    flat_basis: &Array1<f64>,
    weights: &Array1<f64>,
) -> (f64, f64, f64, f64) {
    let n = diff.len();
    // Weighted basis vectors (W * b).
    let wls = weights * ls_basis;
    let whs = weights * hs_basis;
    let w1 = weights * flat_basis;
    // Upper triangle of the symmetric normal matrix A.
    let a00 = ls_basis.dot(&wls);
    let a01 = ls_basis.dot(&whs);
    let a02 = ls_basis.dot(&w1);
    let a11 = hs_basis.dot(&whs);
    let a12 = hs_basis.dot(&w1);
    let a22 = flat_basis.dot(&w1);
    // Right-hand side b = B^T W diff.
    let wd = weights * diff;
    let b0 = ls_basis.dot(&wd);
    let b1 = hs_basis.dot(&wd);
    let b2 = flat_basis.dot(&wd);
    let det = a00 * (a11 * a22 - a12 * a12) - a01 * (a01 * a22 - a12 * a02)
        + a02 * (a01 * a12 - a11 * a02);
    if det.abs() < 1e-30 {
        // Degenerate system: fit only the flat gain, if possible.
        let flat_gain = if a22.abs() > 1e-30 { b2 / a22 } else { 0.0 };
        return (0.0, 0.0, flat_gain, 0.0);
    }
    // x = A^-1 b using the adjugate of the symmetric 3x3 matrix.
    let inv_det = 1.0 / det;
    let x0 = ((a11 * a22 - a12 * a12) * b0
        + (a02 * a12 - a01 * a22) * b1
        + (a01 * a12 - a02 * a11) * b2)
        * inv_det;
    let x1 = ((a02 * a12 - a01 * a22) * b0
        + (a00 * a22 - a02 * a02) * b1
        + (a01 * a02 - a00 * a12) * b2)
        * inv_det;
    let x2 = ((a01 * a12 - a02 * a11) * b0
        + (a01 * a02 - a00 * a12) * b1
        + (a00 * a11 - a01 * a01) * b2)
        * inv_det;
    // Weighted RMS of the linear fit's residual.
    let fitted = ls_basis * x0 + hs_basis * x1 + flat_basis * x2;
    let residual = diff - &fitted;
    let weighted_sq: f64 = residual
        .iter()
        .zip(weights.iter())
        .map(|(&r, &w)| w * r * r)
        .sum();
    let residual_rms = (weighted_sq / n as f64).sqrt();
    (x0, x1, x2, residual_rms)
}
/// Log a one-line summary per channel describing the spectral-alignment
/// corrections that will be applied (or that none are needed).
pub fn log_spectral_alignment(results: &HashMap<String, SpectralAlignmentResult>) {
    for (name, result) in results {
        // Any correction at or above the threshold is worth reporting.
        let needs_correction = result.lowshelf_gain_db.abs() >= MIN_CORRECTION_DB
            || result.highshelf_gain_db.abs() >= MIN_CORRECTION_DB
            || result.flat_gain_db.abs() >= MIN_CORRECTION_DB;
        if needs_correction {
            info!(
                " Channel '{}': spectral alignment LS={:+.2} dB @ {} Hz, \
                HS={:+.2} dB @ {} Hz, gain={:+.2} dB (residual {:.2} dB RMS)",
                name,
                result.lowshelf_gain_db,
                LOWSHELF_FREQ,
                result.highshelf_gain_db,
                HIGHSHELF_FREQ,
                result.flat_gain_db,
                result.residual_rms_db,
            );
        } else {
            info!(
                " Channel '{}': no spectral alignment needed (residual {:.2} dB RMS)",
                name, result.residual_rms_db,
            );
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    const SAMPLE_RATE: f64 = 48000.0;

    // 200-point log-spaced curve over 20 Hz–20 kHz with SPL from `spl_fn`.
    fn make_curve(spl_fn: impl Fn(f64) -> f64) -> Curve {
        let n = 200;
        let log_start = 20f64.log10();
        let log_end = 20000f64.log10();
        let freq: Vec<f64> = (0..n)
            .map(|i| 10f64.powf(log_start + (log_end - log_start) * i as f64 / (n - 1) as f64))
            .collect();
        let spl: Vec<f64> = freq.iter().map(|&f| spl_fn(f)).collect();
        Curve {
            freq: Array1::from(freq),
            spl: Array1::from(spl),
            phase: None,
        }
    }

    // 50-point log-spaced curve restricted to [min_freq, max_freq].
    fn make_narrow_curve(spl_fn: impl Fn(f64) -> f64, min_freq: f64, max_freq: f64) -> Curve {
        let n = 50;
        let log_start = min_freq.log10();
        let log_end = max_freq.log10();
        let freq: Vec<f64> = (0..n)
            .map(|i| 10f64.powf(log_start + (log_end - log_start) * i as f64 / (n - 1) as f64))
            .collect();
        let spl: Vec<f64> = freq.iter().map(|&f| spl_fn(f)).collect();
        Curve {
            freq: Array1::from(freq),
            spl: Array1::from(spl),
            phase: None,
        }
    }

    // A pure level offset between channels must become flat gain only
    // (no shelves), with the flat gains summing to zero.
    #[test]
    fn test_flat_offset() {
        let mut curves = HashMap::new();
        curves.insert("L".to_string(), make_curve(|_| 2.0));
        curves.insert("R".to_string(), make_curve(|_| 0.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        let l = &results["L"];
        let r = &results["R"];
        assert!(
            l.lowshelf_gain_db.abs() < 0.3,
            "L lowshelf should be ~0, got {}",
            l.lowshelf_gain_db
        );
        assert!(
            l.highshelf_gain_db.abs() < 0.3,
            "L highshelf should be ~0, got {}",
            l.highshelf_gain_db
        );
        assert!(
            r.lowshelf_gain_db.abs() < 0.3,
            "R lowshelf should be ~0, got {}",
            r.lowshelf_gain_db
        );
        assert!(
            r.highshelf_gain_db.abs() < 0.3,
            "R highshelf should be ~0, got {}",
            r.highshelf_gain_db
        );
        assert!(
            (l.flat_gain_db + r.flat_gain_db).abs() < 0.01,
            "flat gains should sum to 0"
        );
        assert!(
            l.flat_gain_db < -0.5,
            "L flat should be negative, got {}",
            l.flat_gain_db
        );
        assert!(
            r.flat_gain_db > 0.5,
            "R flat should be positive, got {}",
            r.flat_gain_db
        );
    }

    // Excess bass on one channel must be corrected with a low-shelf cut.
    #[test]
    fn test_bass_tilt() {
        let mut curves = HashMap::new();
        curves.insert(
            "L".to_string(),
            make_curve(|f| if f < 200.0 { 3.0 } else { 0.0 }),
        );
        curves.insert("R".to_string(), make_curve(|_| 0.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        let l = &results["L"];
        assert!(
            l.lowshelf_gain_db < -0.3,
            "L should need LS cut, got {}",
            l.lowshelf_gain_db
        );
        assert!(
            l.highshelf_gain_db.abs() < 1.5,
            "L HS should be small, got {}",
            l.highshelf_gain_db
        );
    }

    // Excess treble on one channel must be corrected with a high-shelf cut.
    #[test]
    fn test_treble_tilt() {
        let mut curves = HashMap::new();
        curves.insert(
            "L".to_string(),
            make_curve(|f| if f > 4000.0 { 3.0 } else { 0.0 }),
        );
        curves.insert("R".to_string(), make_curve(|_| 0.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        let l = &results["L"];
        assert!(
            l.highshelf_gain_db < -0.3,
            "L should need HS cut, got {}",
            l.highshelf_gain_db
        );
        assert!(
            l.lowshelf_gain_db.abs() < l.highshelf_gain_db.abs(),
            "LS ({}) should be smaller than HS ({})",
            l.lowshelf_gain_db,
            l.highshelf_gain_db
        );
    }

    // Extreme deviations must still be clamped to the shelf limits.
    #[test]
    fn test_clamping() {
        let mut curves = HashMap::new();
        curves.insert(
            "L".to_string(),
            make_curve(|f| if f < 200.0 { 20.0 } else { 0.0 }),
        );
        curves.insert("R".to_string(), make_curve(|_| 0.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        for result in results.values() {
            assert!(
                result.lowshelf_gain_db.abs() <= MAX_SHELF_GAIN_DB + 0.01,
                "LS gain {} exceeds max ±{}",
                result.lowshelf_gain_db,
                MAX_SHELF_GAIN_DB
            );
            assert!(
                result.highshelf_gain_db.abs() <= MAX_SHELF_GAIN_DB + 0.01,
                "HS gain {} exceeds max ±{}",
                result.highshelf_gain_db,
                MAX_SHELF_GAIN_DB
            );
        }
    }

    #[test]
    fn test_single_channel() {
        let mut curves = HashMap::new();
        curves.insert("L".to_string(), make_curve(|_| 0.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        assert!(
            results.is_empty(),
            "Single channel should produce no alignment"
        );
    }

    // The closed-form WLS solver must recover known basis coefficients.
    #[test]
    fn test_solver_identity() {
        let n = 100;
        let freq = Array1::linspace(20.0, 20000.0, n);
        let (ls_basis, hs_basis) = build_basis_vectors(&freq, SAMPLE_RATE);
        let flat_basis = Array1::ones(n);
        let weights = compute_octave_weights(&freq);
        let diff = &ls_basis * 2.0 + &hs_basis * 3.0 + &flat_basis * 1.0;
        let (ls, hs, flat, residual) =
            solve_3x3_wls(&diff, &ls_basis, &hs_basis, &flat_basis, &weights);
        assert!((ls - 2.0).abs() < 0.01, "LS should be 2.0, got {}", ls);
        assert!((hs - 3.0).abs() < 0.01, "HS should be 3.0, got {}", hs);
        assert!(
            (flat - 1.0).abs() < 0.01,
            "flat should be 1.0, got {}",
            flat
        );
        assert!(residual < 0.01, "residual should be ~0, got {}", residual);
    }

    #[test]
    fn test_create_alignment_plugins_shelves_and_gain() {
        let result = SpectralAlignmentResult {
            lowshelf_gain_db: -2.0,
            highshelf_gain_db: 1.5,
            flat_gain_db: -1.0,
            residual_rms_db: 0.5,
        };
        let (eq, gain) = create_alignment_plugins(&result, SAMPLE_RATE);
        assert!(eq.is_some(), "should have EQ plugin for shelves");
        let eq = eq.unwrap();
        assert_eq!(eq.plugin_type, "eq");
        let filters = eq.parameters["filters"].as_array().unwrap();
        assert_eq!(filters.len(), 2, "should have LS + HS");
        assert!(gain.is_some(), "should have gain plugin");
        let gain = gain.unwrap();
        assert_eq!(gain.plugin_type, "gain");
    }

    #[test]
    fn test_create_alignment_plugins_gain_only() {
        let result = SpectralAlignmentResult {
            lowshelf_gain_db: 0.0,
            highshelf_gain_db: 0.0,
            flat_gain_db: -2.0,
            residual_rms_db: 0.3,
        };
        let (eq, gain) = create_alignment_plugins(&result, SAMPLE_RATE);
        assert!(eq.is_none(), "no shelves → no EQ plugin");
        assert!(gain.is_some(), "should have gain plugin");
    }

    #[test]
    fn test_create_alignment_plugins_none() {
        let result = SpectralAlignmentResult {
            lowshelf_gain_db: 0.0,
            highshelf_gain_db: 0.0,
            flat_gain_db: 0.0,
            residual_rms_db: 0.1,
        };
        let (eq, gain) = create_alignment_plugins(&result, SAMPLE_RATE);
        assert!(eq.is_none());
        assert!(gain.is_none());
    }

    // LM refinement must recover exact gains where the linear seed alone
    // is biased (large shelf gains), and beat the linear-only residual.
    #[test]
    fn test_iterative_improves_large_gain_accuracy() {
        let n = 200;
        let log_start = 20f64.log10();
        let log_end = 20000f64.log10();
        let freq: Array1<f64> = Array1::from(
            (0..n)
                .map(|i| 10f64.powf(log_start + (log_end - log_start) * i as f64 / (n - 1) as f64))
                .collect::<Vec<_>>(),
        );
        let true_ls = 5.0;
        let true_hs = -4.0;
        let true_flat = 1.0;
        let diff = evaluate_shelf_response(&freq, SAMPLE_RATE, true_ls, true_hs, true_flat);
        let weights = compute_octave_weights(&freq);
        let (ls, hs, flat, residual) =
            fit_shelf_gain_iterative(&diff, &freq, SAMPLE_RATE, &weights);
        assert!(
            (ls - true_ls).abs() < 0.05,
            "LS should be {}, got {} (error {})",
            true_ls,
            ls,
            (ls - true_ls).abs()
        );
        assert!(
            (hs - true_hs).abs() < 0.05,
            "HS should be {}, got {} (error {})",
            true_hs,
            hs,
            (hs - true_hs).abs()
        );
        assert!(
            (flat - true_flat).abs() < 0.05,
            "flat should be {}, got {} (error {})",
            true_flat,
            flat,
            (flat - true_flat).abs()
        );
        assert!(residual < 0.01, "residual should be ~0, got {}", residual);
        let flat_basis = Array1::ones(n);
        let (ls_basis, hs_basis) = build_basis_vectors(&freq, SAMPLE_RATE);
        let (lin_ls, lin_hs, lin_flat, _) =
            solve_3x3_wls(&diff, &ls_basis, &hs_basis, &flat_basis, &weights);
        let lin_response = evaluate_shelf_response(&freq, SAMPLE_RATE, lin_ls, lin_hs, lin_flat);
        let lin_residual_vec = &diff - &lin_response;
        let lin_rms = (lin_residual_vec
            .iter()
            .zip(weights.iter())
            .map(|(&r, &w)| w * r * r)
            .sum::<f64>()
            / n as f64)
            .sqrt();
        assert!(
            residual < lin_rms,
            "Iterative residual ({:.4}) should be less than linear-only ({:.4})",
            residual,
            lin_rms
        );
    }

    #[test]
    fn test_three_channels() {
        let mut curves = HashMap::new();
        curves.insert(
            "L".to_string(),
            make_curve(|f| if f < 200.0 { 2.0 } else { 0.0 }),
        );
        curves.insert("C".to_string(), make_curve(|_| 0.0));
        curves.insert(
            "R".to_string(),
            make_curve(|f| if f > 4000.0 { 2.0 } else { 0.0 }),
        );
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        assert_eq!(results.len(), 3);
        let flat_sum: f64 = results.values().map(|r| r.flat_gain_db).sum();
        assert!(
            flat_sum.abs() < 0.1,
            "flat gains should sum to ~0, got {}",
            flat_sum
        );
    }

    // A narrow analysis band must not make the fit diverge or exceed
    // the flat-gain clamp.
    #[test]
    fn test_narrow_band_no_divergence() {
        let mut curves = HashMap::new();
        curves.insert("L".to_string(), make_narrow_curve(|_| -30.0, 100.0, 400.0));
        curves.insert("R".to_string(), make_narrow_curve(|_| -32.0, 100.0, 400.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 100.0, 400.0);
        for (name, r) in &results {
            assert!(
                r.flat_gain_db.abs() <= MAX_FLAT_GAIN_DB + 0.01,
                "Channel '{}' flat_gain {:.2} dB exceeds ±{} dB",
                name,
                r.flat_gain_db,
                MAX_FLAT_GAIN_DB
            );
            assert!(
                r.flat_gain_db.is_finite(),
                "Channel '{}' flat_gain is not finite",
                name
            );
            assert!(
                r.lowshelf_gain_db.is_finite(),
                "Channel '{}' lowshelf_gain is not finite",
                name
            );
            assert!(
                r.highshelf_gain_db.is_finite(),
                "Channel '{}' highshelf_gain is not finite",
                name
            );
        }
    }

    #[test]
    fn test_identical_channels_zero_correction() {
        let mut curves = HashMap::new();
        curves.insert("L".to_string(), make_curve(|_| 0.0));
        curves.insert("R".to_string(), make_curve(|_| 0.0));
        let results = compute_spectral_alignment(&curves, SAMPLE_RATE, 20.0, 20000.0);
        for (name, r) in &results {
            assert!(
                r.flat_gain_db.abs() < MIN_CORRECTION_DB,
                "Channel '{}' flat_gain should be ~0, got {:.4}",
                name,
                r.flat_gain_db
            );
            assert!(
                r.lowshelf_gain_db.abs() < MIN_CORRECTION_DB,
                "Channel '{}' lowshelf should be ~0, got {:.4}",
                name,
                r.lowshelf_gain_db
            );
            assert!(
                r.highshelf_gain_db.abs() < MIN_CORRECTION_DB,
                "Channel '{}' highshelf should be ~0, got {:.4}",
                name,
                r.highshelf_gain_db
            );
        }
    }

    // A level-aligned target must not trigger a large flat-gain correction;
    // an un-aligned target should.
    #[test]
    fn test_target_alignment_level_offset_must_not_cause_large_correction() {
        let measurement = make_curve(|_| 5.0);
        let target = make_curve(|f| 5.0 + (-0.8) * (f / 1000.0).log2());
        let result = compute_target_alignment(&measurement, &target, 20.0, 20000.0, SAMPLE_RATE);
        if let Some(r) = &result {
            assert!(
                r.flat_gain_db.abs() < 3.0,
                "flat_gain should be small when target is level-aligned, got {:.2}dB",
                r.flat_gain_db
            );
        }
        let bad_target = make_curve(|f| (-0.8) * (f / 1000.0).log2());
        let bad_result =
            compute_target_alignment(&measurement, &bad_target, 20.0, 20000.0, SAMPLE_RATE);
        if let Some(r) = &bad_result {
            assert!(
                r.flat_gain_db.abs() > 3.0,
                "un-aligned target should produce large flat_gain, got {:.2}dB",
                r.flat_gain_db
            );
        }
    }

    #[test]
    fn test_target_alignment_same_level_flat() {
        let mean_level = 7.0;
        let measurement = make_curve(|_| mean_level);
        let target = make_curve(|_| mean_level);
        let result = compute_target_alignment(&measurement, &target, 20.0, 20000.0, SAMPLE_RATE);
        if let Some(r) = &result {
            assert!(
                r.flat_gain_db.abs() < MIN_CORRECTION_DB,
                "flat_gain should be negligible, got {:.4}dB",
                r.flat_gain_db
            );
            assert!(
                r.lowshelf_gain_db.abs() < MIN_CORRECTION_DB,
                "lowshelf should be negligible, got {:.4}dB",
                r.lowshelf_gain_db
            );
            assert!(
                r.highshelf_gain_db.abs() < MIN_CORRECTION_DB,
                "highshelf should be negligible, got {:.4}dB",
                r.highshelf_gain_db
            );
        }
    }

    // A tilted target must be expressed through the shelves, not the
    // broadband gain.
    #[test]
    fn test_target_alignment_tilt_produces_shelf_not_flat() {
        let mean_level = 5.0;
        let measurement = make_curve(|_| mean_level);
        let target = make_curve(|f| mean_level + (-0.8) * (f / 1000.0).log2());
        let result = compute_target_alignment(&measurement, &target, 20.0, 20000.0, SAMPLE_RATE);
        if let Some(r) = result {
            assert!(
                r.flat_gain_db.abs() < 2.0,
                "flat_gain should be small for pure tilt, got {:.2}dB",
                r.flat_gain_db
            );
            let has_shelf = r.lowshelf_gain_db.abs() > MIN_CORRECTION_DB
                || r.highshelf_gain_db.abs() > MIN_CORRECTION_DB;
            assert!(has_shelf, "tilt should produce shelf corrections");
        }
    }

    #[test]
    fn test_broadband_must_use_flat_target_not_tilted() {
        let measurement = make_curve(|_| 5.0);
        let flat_target = make_curve(|_| 5.0);
        let flat_result =
            compute_target_alignment(&measurement, &flat_target, 20.0, 20000.0, SAMPLE_RATE);
        let tilted_target = make_curve(|f| 5.0 + (-0.8) * (f / 1000.0).log2());
        let tilted_result =
            compute_target_alignment(&measurement, &tilted_target, 20.0, 20000.0, SAMPLE_RATE);
        let flat_total = flat_result
            .as_ref()
            .map(|r| r.flat_gain_db.abs() + r.lowshelf_gain_db.abs() + r.highshelf_gain_db.abs())
            .unwrap_or(0.0);
        assert!(
            flat_total < 1.0,
            "flat measurement + flat target should need negligible correction, got {:.2}dB",
            flat_total
        );
        if let Some(r) = &tilted_result {
            let tilted_total =
                r.flat_gain_db.abs() + r.lowshelf_gain_db.abs() + r.highshelf_gain_db.abs();
            assert!(
                tilted_total > 1.0,
                "flat measurement + tilted target should produce shelf corrections, got {:.2}dB",
                tilted_total
            );
        }
    }

    // Even a 10 dB narrow peak must never push the broadband corrections
    // past their clamps.
    #[test]
    fn test_broadband_corrections_are_gentle() {
        let measurement = make_curve(|f| {
            let peak = 10.0 * (-((f.log2() - 300.0_f64.log2()).powi(2)) / 0.5).exp();
            5.0 + peak
        });
        let target = make_curve(|_| 5.0);
        let result = compute_target_alignment(&measurement, &target, 20.0, 20000.0, SAMPLE_RATE);
        if let Some(r) = result {
            assert!(
                r.lowshelf_gain_db.abs() <= MAX_SHELF_GAIN_DB + 0.01,
                "lowshelf {:.2}dB exceeds limit {:.1}dB",
                r.lowshelf_gain_db,
                MAX_SHELF_GAIN_DB
            );
            assert!(
                r.highshelf_gain_db.abs() <= MAX_SHELF_GAIN_DB + 0.01,
                "highshelf {:.2}dB exceeds limit {:.1}dB",
                r.highshelf_gain_db,
                MAX_SHELF_GAIN_DB
            );
            assert!(
                r.flat_gain_db.abs() <= MAX_FLAT_GAIN_DB + 0.01,
                "flat_gain {:.2}dB exceeds limit {:.1}dB",
                r.flat_gain_db,
                MAX_FLAT_GAIN_DB
            );
        }
    }
}