use crate::Curve;
use std::path::Path;
pub use math_audio_iir_fir::{
FirDesignConfig, FirPhase, WindowType,
generate_kirkeby_correction as generate_kirkeby_correction_raw, save_fir_to_wav,
};
pub fn generate_fir_from_response(
target_curve: &Curve,
sample_rate: f64,
n_taps: usize,
phase_type: FirPhase,
) -> Vec<f64> {
let config = FirDesignConfig {
n_taps,
sample_rate,
phase: phase_type,
..Default::default()
};
let freqs: Vec<f64> = target_curve.freq.to_vec();
let magnitude_db: Vec<f64> = target_curve.spl.to_vec();
math_audio_iir_fir::generate_fir_from_response(&freqs, &magnitude_db, &config)
}
pub fn generate_kirkeby_correction(
measurement: &Curve,
target: &Curve,
sample_rate: f64,
n_taps: usize,
min_freq: f64,
max_freq: f64,
) -> Vec<f64> {
generate_kirkeby_correction_with_phase(
measurement,
target,
sample_rate,
n_taps,
min_freq,
max_freq,
false, )
}
pub fn generate_kirkeby_correction_with_phase(
measurement: &Curve,
target: &Curve,
sample_rate: f64,
n_taps: usize,
min_freq: f64,
max_freq: f64,
correct_excess_phase: bool,
) -> Vec<f64> {
generate_kirkeby_correction_with_smoothing(
measurement,
target,
sample_rate,
n_taps,
min_freq,
max_freq,
correct_excess_phase,
0.167, )
}
pub fn generate_kirkeby_correction_with_smoothing(
measurement: &Curve,
target: &Curve,
sample_rate: f64,
n_taps: usize,
min_freq: f64,
max_freq: f64,
correct_excess_phase: bool,
phase_smoothing_octaves: f64,
) -> Vec<f64> {
let config = FirDesignConfig {
n_taps,
sample_rate,
phase: FirPhase::Kirkeby,
min_freq,
max_freq,
correct_excess_phase,
phase_smoothing_octaves,
..Default::default()
};
let meas_freqs: Vec<f64> = measurement.freq.to_vec();
let meas_db: Vec<f64> = measurement.spl.to_vec();
let meas_phase: Option<Vec<f64>> = measurement.phase.as_ref().map(|p| p.to_vec());
let target_db: Vec<f64> = if target.freq.len() == measurement.freq.len() {
target.spl.to_vec()
} else {
let interpolated = crate::read::interpolate(&measurement.freq, target);
interpolated.spl.to_vec()
};
generate_kirkeby_correction_raw(
&meas_freqs,
&meas_db,
meas_phase.as_deref(),
&target_db,
&config,
)
}
pub fn save_fir_wav(
coeffs: &[f64],
sample_rate: u32,
path: &Path,
) -> Result<(), Box<dyn std::error::Error>> {
save_fir_to_wav(coeffs, sample_rate, path)
}
#[cfg(test)]
mod tests {
use super::*;
use ndarray::Array1;
use tempfile::TempDir;
fn create_test_curve(freqs: &[f64], spl_values: &[f64]) -> Curve {
Curve {
freq: Array1::from(freqs.to_vec()),
spl: Array1::from(spl_values.to_vec()),
phase: None,
}
}
fn create_flat_curve(min_freq: f64, max_freq: f64, n_points: usize, spl_db: f64) -> Curve {
let freqs: Vec<f64> = (0..n_points)
.map(|i| {
let t = i as f64 / (n_points - 1) as f64;
min_freq * (max_freq / min_freq).powf(t)
})
.collect();
let spl: Vec<f64> = vec![spl_db; n_points];
create_test_curve(&freqs, &spl)
}
fn compute_energy_in_range(coeffs: &[f64], start_fraction: f64, end_fraction: f64) -> f64 {
let n = coeffs.len();
let start = (n as f64 * start_fraction) as usize;
let end = (n as f64 * end_fraction) as usize;
coeffs[start..end].iter().map(|x| x * x).sum()
}
#[test]
fn test_linear_phase_impulse_symmetry() {
let sample_rate = 48000.0;
let n_taps = 512;
let target_curve = create_test_curve(
&[20.0, 100.0, 1000.0, 5000.0, 20000.0],
&[0.0, 2.0, 0.0, -1.0, -2.0],
);
let coeffs =
generate_fir_from_response(&target_curve, sample_rate, n_taps, FirPhase::Linear);
assert_eq!(coeffs.len(), n_taps);
let (max_idx, _) = coeffs
.iter()
.enumerate()
.max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
.unwrap();
let center = n_taps / 2;
let tolerance = n_taps / 10;
assert!(
(max_idx as isize - center as isize).unsigned_abs() < tolerance,
"Linear phase FIR peak should be near center. Peak at {}, center at {}",
max_idx,
center
);
}
#[test]
fn test_minimum_phase_energy_concentration() {
let sample_rate = 48000.0;
let n_taps = 1024;
let target_curve = create_test_curve(
&[20.0, 100.0, 500.0, 1000.0, 5000.0, 20000.0],
&[-3.0, 0.0, 2.0, 0.0, -2.0, -5.0],
);
let coeffs =
generate_fir_from_response(&target_curve, sample_rate, n_taps, FirPhase::Minimum);
assert_eq!(coeffs.len(), n_taps);
let first_half_energy = compute_energy_in_range(&coeffs, 0.0, 0.5);
let second_half_energy = compute_energy_in_range(&coeffs, 0.5, 1.0);
assert!(
first_half_energy > second_half_energy,
"Minimum phase should have more energy in first half: first={:.4}, second={:.4}",
first_half_energy,
second_half_energy
);
}
#[test]
fn test_flat_target_produces_near_impulse() {
let sample_rate = 48000.0;
let n_taps = 256;
let target_curve = create_flat_curve(20.0, 20000.0, 100, 0.0);
let coeffs =
generate_fir_from_response(&target_curve, sample_rate, n_taps, FirPhase::Linear);
assert_eq!(coeffs.len(), n_taps);
let (max_idx, max_val) = coeffs
.iter()
.enumerate()
.max_by(|a, b| a.1.abs().partial_cmp(&b.1.abs()).unwrap())
.unwrap();
let center = n_taps / 2;
assert!(
(max_idx as isize - center as isize).abs() < 10,
"Peak should be near center for linear phase"
);
assert!(*max_val > 0.0, "Peak coefficient should be positive");
}
#[test]
fn test_save_fir_to_wav_creates_valid_file() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let wav_path = temp_dir.path().join("test_fir.wav");
let coeffs: Vec<f64> = (0..256).map(|i| (i as f64 * 0.01).sin()).collect();
let result = save_fir_wav(&coeffs, 48000, &wav_path);
assert!(result.is_ok(), "save_fir_wav should succeed");
assert!(wav_path.exists(), "WAV file should be created");
let reader = hound::WavReader::open(&wav_path).expect("Should open WAV file");
let spec = reader.spec();
assert_eq!(spec.channels, 1);
assert_eq!(spec.sample_rate, 48000);
assert_eq!(spec.bits_per_sample, 32);
assert_eq!(reader.len() as usize, coeffs.len());
}
#[test]
fn test_fir_phase_types_differ() {
let sample_rate = 48000.0;
let n_taps = 512;
let target_curve = create_test_curve(
&[20.0, 100.0, 1000.0, 10000.0, 20000.0],
&[0.0, 3.0, 0.0, -3.0, -6.0],
);
let linear_coeffs =
generate_fir_from_response(&target_curve, sample_rate, n_taps, FirPhase::Linear);
let minimum_coeffs =
generate_fir_from_response(&target_curve, sample_rate, n_taps, FirPhase::Minimum);
assert_eq!(linear_coeffs.len(), minimum_coeffs.len());
let sum_diff: f64 = linear_coeffs
.iter()
.zip(minimum_coeffs.iter())
.map(|(a, b)| (a - b).abs())
.sum();
assert!(
sum_diff > 0.1,
"Linear and minimum phase should produce different coefficients"
);
}
#[test]
fn test_kirkeby_correction() {
let measurement = create_test_curve(
&[20.0, 100.0, 500.0, 1000.0, 5000.0, 20000.0],
&[75.0, 82.0, 80.0, 78.0, 72.0, 65.0],
);
let target = create_test_curve(
&[20.0, 100.0, 500.0, 1000.0, 5000.0, 20000.0],
&[80.0, 80.0, 80.0, 80.0, 80.0, 80.0],
);
let coeffs =
generate_kirkeby_correction(&measurement, &target, 48000.0, 4096, 20.0, 1000.0);
assert_eq!(coeffs.len(), 4096);
assert!(coeffs.iter().any(|&x| x.abs() > 1e-10));
}
}