use crate::Curve;
use ndarray::Array1;
use std::error::Error;
use super::types::{OptimizerConfig, TargetCurveConfig};
pub use crate::fir::FirPhase;
pub fn generate_fir_correction(
measurement: &Curve,
config: &OptimizerConfig,
target_config: Option<&TargetCurveConfig>,
sample_rate: f64,
) -> Result<Vec<f64>, Box<dyn Error>> {
let target_curve = match target_config {
Some(TargetCurveConfig::Path(path)) => {
let target = crate::read::read_curve_from_csv(path)?;
crate::read::normalize_and_interpolate_response(&measurement.freq, &target)
}
Some(TargetCurveConfig::Predefined(name)) => {
use crate::cli::Args;
use clap::Parser;
let dummy_args = Args::parse_from(["autoeq", "--curve-name", name]);
match crate::workflow::build_target_curve(&dummy_args, &measurement.freq, measurement) {
Ok(curve) => curve,
Err(_) => {
let target = crate::read::read_curve_from_csv(&std::path::PathBuf::from(name))?;
crate::read::normalize_and_interpolate_response(&measurement.freq, &target)
}
}
}
None => {
let min_freq = config.min_freq;
let max_freq = config.max_freq;
let mut sum = 0.0;
let mut count = 0;
for i in 0..measurement.freq.len() {
if measurement.freq[i] >= min_freq && measurement.freq[i] <= max_freq {
sum += measurement.spl[i];
count += 1;
}
}
let mean_level = if count > 0 { sum / count as f64 } else { 0.0 };
Curve {
freq: measurement.freq.clone(),
spl: Array1::from_elem(measurement.freq.len(), mean_level),
phase: None,
}
}
};
let fir_config = config.fir.as_ref().ok_or("FIR configuration missing")?;
let n_taps = fir_config.taps;
if fir_config.phase.to_lowercase() == "kirkeby" {
let coeffs = crate::fir::generate_kirkeby_correction_with_smoothing(
measurement,
&target_curve,
sample_rate,
n_taps,
config.min_freq,
config.max_freq,
fir_config.correct_excess_phase,
fir_config.phase_smoothing,
);
Ok(coeffs)
} else {
let correction_spl = &target_curve.spl - &measurement.spl;
let correction_curve = Curve {
freq: measurement.freq.clone(),
spl: correction_spl,
phase: None,
};
let phase_type = match fir_config.phase.to_lowercase().as_str() {
"linear" => FirPhase::Linear,
"minimum" => FirPhase::Minimum,
_ => return Err(format!("Unknown FIR phase type: {}", fir_config.phase).into()),
};
let pre_ringing =
fir_config
.pre_ringing
.as_ref()
.map(|pr| math_audio_iir_fir::PreRingingConfig {
threshold_db: pr.threshold_db,
max_time_s: pr.max_time_s,
});
let fir_design_config = math_audio_iir_fir::FirDesignConfig {
n_taps,
sample_rate,
phase: phase_type,
pre_ringing,
..Default::default()
};
let freqs: Vec<f64> = correction_curve.freq.to_vec();
let magnitude_db: Vec<f64> = correction_curve.spl.to_vec();
let coeffs = math_audio_iir_fir::generate_fir_from_response(
&freqs,
&magnitude_db,
&fir_design_config,
);
Ok(coeffs)
}
}
#[cfg(test)]
pub use math_audio_iir_fir::{WindowType, generate_window};
#[cfg(test)]
// Several tests start from `OptimizerConfig::default()` and then overwrite individual fields,
// which is exactly the pattern this lint reports.
#[allow(clippy::field_reassign_with_default)]
mod tests {
    use super::*;
    use crate::roomeq::types::FirConfig;
    use ndarray::Array1;

    /// Panics unless `actual` and `expected` differ by strictly less than `tol`.
    fn assert_approx_eq(actual: f64, expected: f64, tol: f64) {
        let diff = (actual - expected).abs();
        assert!(
            diff < tol,
            "assertion failed: {} ≈ {} (diff = {}, epsilon = {})",
            actual,
            expected,
            diff,
            tol
        );
    }

    /// Builds a `Curve` from parallel slices, optionally with phase values.
    fn build_curve(freqs: &[f64], spl_values: &[f64], phase_deg: Option<&[f64]>) -> Curve {
        Curve {
            freq: Array1::from(freqs.to_vec()),
            spl: Array1::from(spl_values.to_vec()),
            phase: phase_deg.map(|values| Array1::from(values.to_vec())),
        }
    }

    /// Magnitude-only test curve.
    fn create_test_curve(freqs: &[f64], spl_values: &[f64]) -> Curve {
        build_curve(freqs, spl_values, None)
    }

    /// Test curve that also carries phase data.
    fn create_test_curve_with_phase(freqs: &[f64], spl_values: &[f64], phase_deg: &[f64]) -> Curve {
        build_curve(freqs, spl_values, Some(phase_deg))
    }

    /// Flat 80 dB target from 20 Hz to 20 kHz, shared by the Kirkeby tests.
    fn flat_80db_target() -> Curve {
        create_test_curve(&[20.0, 100.0, 1000.0, 10000.0, 20000.0], &[80.0; 5])
    }

    /// Six-point measurement shared by the `generate_fir_correction` success-path tests.
    fn six_point_measurement() -> Curve {
        create_test_curve(
            &[20.0, 100.0, 500.0, 1000.0, 5000.0, 20000.0],
            &[78.0, 82.0, 80.0, 79.0, 75.0, 70.0],
        )
    }

    /// Three-point flat 80 dB measurement shared by the error-path tests.
    fn flat_three_point_measurement() -> Curve {
        create_test_curve(&[20.0, 1000.0, 20000.0], &[80.0, 80.0, 80.0])
    }

    /// Default optimizer config with the given FIR settings installed.
    fn config_with_fir(fir: FirConfig) -> OptimizerConfig {
        let mut config = OptimizerConfig::default();
        config.fir = Some(fir);
        config
    }

    #[test]
    fn test_hann_window_symmetry() {
        let window = generate_window(8, WindowType::Hann, 0.0);
        for (lo, hi) in [(0, 7), (1, 6), (2, 5), (3, 4)] {
            assert_approx_eq(window[lo], window[hi], 0.01);
        }
    }

    #[test]
    fn test_hann_window_endpoints() {
        let window = generate_window(128, WindowType::Hann, 0.0);
        assert!(window[0] < 0.01);
        assert!(window[127] < 0.01);
        assert!(window[64] > 0.99);
    }

    #[test]
    fn test_hamming_window_endpoints() {
        let window = generate_window(128, WindowType::Hamming, 0.0);
        assert!(window[0] > 0.07 && window[0] < 0.09);
        assert!(window[64] > 0.99);
    }

    #[test]
    fn test_blackman_window_endpoints() {
        let window = generate_window(128, WindowType::Blackman, 0.0);
        assert!(window[0] < 0.01);
        assert!(window[64] > 0.99);
    }

    #[test]
    fn test_kaiser_window_beta_0() {
        generate_window(8, WindowType::Kaiser, 0.0)
            .into_iter()
            .for_each(|sample| assert_approx_eq(sample, 1.0, 0.01));
    }

    #[test]
    fn test_rectangular_window() {
        let window = generate_window(10, WindowType::Rectangular, 0.0);
        assert_eq!(window.len(), 10);
        window.into_iter().for_each(|sample| assert_eq!(sample, 1.0));
    }

    #[test]
    fn test_kirkeby_with_phase_data() {
        let freqs = [
            20.0, 50.0, 100.0, 200.0, 500.0, 1000.0, 2000.0, 5000.0, 10000.0, 20000.0,
        ];
        let levels = [75.0, 80.0, 85.0, 82.0, 80.0, 78.0, 76.0, 74.0, 70.0, 65.0];
        let phases = [
            -180.0, -120.0, -60.0, -30.0, 0.0, 30.0, 60.0, 90.0, 120.0, 150.0,
        ];
        let measured = create_test_curve_with_phase(&freqs, &levels, &phases);
        let taps = crate::fir::generate_kirkeby_correction(
            &measured,
            &flat_80db_target(),
            48000.0,
            4096,
            20.0,
            1000.0,
        );
        assert_eq!(taps.len(), 4096);
        assert!(taps.iter().any(|&x| x.abs() > 1e-10));
    }

    #[test]
    fn test_kirkeby_without_phase_data() {
        let measured = create_test_curve(
            &[20.0, 100.0, 500.0, 1000.0, 5000.0, 20000.0],
            &[75.0, 82.0, 80.0, 78.0, 72.0, 65.0],
        );
        let taps = crate::fir::generate_kirkeby_correction(
            &measured,
            &flat_80db_target(),
            48000.0,
            4096,
            20.0,
            1000.0,
        );
        assert_eq!(taps.len(), 4096);
    }

    #[test]
    fn test_generate_fir_correction_basic() {
        let mut config = config_with_fir(FirConfig {
            taps: 1024,
            phase: "linear".to_string(),
            correct_excess_phase: false,
            phase_smoothing: 0.167,
            pre_ringing: None,
        });
        config.min_freq = 50.0;
        config.max_freq = 2000.0;
        let outcome = generate_fir_correction(&six_point_measurement(), &config, None, 48000.0);
        assert!(
            outcome.is_ok(),
            "FIR correction should succeed: {:?}",
            outcome.err()
        );
        assert_eq!(outcome.unwrap().len(), 1024);
    }

    #[test]
    fn test_generate_fir_correction_kirkeby_mode() {
        let mut config = config_with_fir(FirConfig {
            taps: 2048,
            phase: "kirkeby".to_string(),
            correct_excess_phase: false,
            phase_smoothing: 0.167,
            pre_ringing: None,
        });
        config.min_freq = 20.0;
        config.max_freq = 500.0;
        let outcome = generate_fir_correction(&six_point_measurement(), &config, None, 48000.0);
        assert!(outcome.is_ok(), "Kirkeby FIR correction should succeed");
        assert_eq!(outcome.unwrap().len(), 2048);
    }

    #[test]
    fn test_fir_config_missing_returns_error() {
        let config = OptimizerConfig::default();
        let outcome =
            generate_fir_correction(&flat_three_point_measurement(), &config, None, 48000.0);
        assert!(outcome.is_err(), "Should error when FIR config is missing");
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("FIR configuration missing"),
            "Error should mention missing FIR config"
        );
    }

    #[test]
    fn test_invalid_phase_type_returns_error() {
        let config = config_with_fir(FirConfig {
            taps: 1024,
            phase: "invalid_phase_type".to_string(),
            correct_excess_phase: false,
            phase_smoothing: 0.167,
            pre_ringing: None,
        });
        let outcome =
            generate_fir_correction(&flat_three_point_measurement(), &config, None, 48000.0);
        assert!(outcome.is_err(), "Should error on invalid phase type");
        let message = outcome.unwrap_err().to_string();
        assert!(
            message.contains("Unknown FIR phase type"),
            "Error should mention unknown phase type"
        );
    }
}