use ndarray::Array1;
use rand::rngs::StdRng;
use rand::{Rng, SeedableRng};
/// Tuning knobs for [`create_smart_initial_guesses`].
#[derive(Debug, Clone)]
pub struct SmartInitConfig {
    /// Number of initial parameter vectors to generate.
    pub num_guesses: usize,
    /// Smoothing strength applied before peak/dip detection. Despite the name
    /// this is not a Gaussian width: it is rounded to the nearest integer
    /// (minimum 1) and used as N for 1/N-octave smoothing.
    pub smoothing_sigma: f64,
    /// Minimum smoothed level a local maximum must reach to count as a peak;
    /// a dip must reach at or below the negated value. Same units as the
    /// response (presumably dB).
    pub min_peak_height: f64,
    /// Minimum spacing, in grid samples, between detected peaks (and between
    /// detected dips).
    pub min_peak_distance: usize,
    /// Fallback centre frequencies used to pad the candidate list when fewer
    /// peaks/dips than filters are found; only those inside the grid's span
    /// are used. Same units as the frequency grid (presumably Hz).
    pub critical_frequencies: Vec<f64>,
    /// Maximum relative jitter applied to each filter's frequency per guess;
    /// any non-positive value disables it.
    pub variation_factor: f64,
    /// RNG seed. `Some` makes the random draws reproducible; `None` uses
    /// `rand::rng()`.
    pub seed: Option<u64>,
}

impl Default for SmartInitConfig {
    /// Five guesses, 1/2-octave smoothing, a peak threshold of 1, 10-sample
    /// peak spacing, ±10 % frequency jitter and an unseeded RNG.
    fn default() -> Self {
        const CRITICAL_HZ: [f64; 6] = [100.0, 300.0, 1000.0, 3000.0, 8000.0, 16000.0];
        SmartInitConfig {
            seed: None,
            num_guesses: 5,
            variation_factor: 0.1,
            smoothing_sigma: 2.0,
            min_peak_height: 1.0,
            min_peak_distance: 10,
            critical_frequencies: CRITICAL_HZ.to_vec(),
        }
    }
}
/// One correction candidate: where to place a filter and how hard to push.
#[derive(Debug, Clone)]
struct FrequencyProblem {
    /// Centre frequency, in the units of the frequency grid.
    frequency: f64,
    /// Signed gain to apply (negative cuts a peak, positive fills a dip).
    magnitude: f64,
    /// Initial filter Q before per-guess jitter.
    q_factor: f64,
}
/// Smooths `data` (sampled on `freq_grid`) before peak/dip detection.
///
/// `sigma` is rounded to the nearest integer, floored at 1, and passed as the
/// N of [`crate::read::smooth_one_over_n_octave`]. Only the resulting SPL
/// array is returned; no phase is supplied.
fn smooth_problem_response(freq_grid: &Array1<f64>, data: &Array1<f64>, sigma: f64) -> Array1<f64> {
    let n_per_octave = sigma.max(1.0).round() as usize;
    let raw = crate::Curve {
        freq: freq_grid.clone(),
        spl: data.clone(),
        phase: None,
    };
    let smoothed = crate::read::smooth_one_over_n_octave(&raw, n_per_octave);
    smoothed.spl
}
/// Returns the indices of strict interior local maxima of `data` whose value
/// is at least `min_height`, in ascending index order.
///
/// When two maxima are fewer than `min_distance` samples apart, only the
/// higher one is kept (on equal height the lower index wins). This is the
/// height-priority rule SciPy's `find_peaks` uses for its `distance`
/// argument, so a small ripple just before a large resonance can no longer
/// suppress the resonance. The first and last samples, and flat-topped
/// plateaus, are never reported. Inputs shorter than 3 samples yield an
/// empty `Vec`.
fn find_peaks(data: &Array1<f64>, min_height: f64, min_distance: usize) -> Vec<usize> {
    let n = data.len();
    if n < 3 {
        return Vec::new();
    }
    // Strict local maxima that clear the height threshold, in index order.
    let mut by_height: Vec<usize> = (1..n - 1)
        .filter(|&i| data[i] > data[i - 1] && data[i] > data[i + 1] && data[i] >= min_height)
        .collect();
    // Visit the tallest first. `sort_by` is stable, so equal heights stay in
    // index order (lower index wins the tie).
    by_height.sort_by(|&a, &b| {
        data[b]
            .partial_cmp(&data[a])
            .unwrap_or(std::cmp::Ordering::Equal)
    });
    // A kept peak suppresses every candidate within `min_distance - 1`
    // samples, i.e. kept peaks are always at least `min_distance` apart.
    let reach = min_distance.saturating_sub(1);
    let mut kept: std::collections::BTreeSet<usize> = std::collections::BTreeSet::new();
    for idx in by_height {
        let neighbourhood = idx.saturating_sub(reach)..=idx.saturating_add(reach);
        if kept.range(neighbourhood).next().is_none() {
            kept.insert(idx);
        }
    }
    kept.into_iter().collect()
}
pub fn create_smart_initial_guesses(
target_response: &Array1<f64>,
freq_grid: &Array1<f64>,
num_filters: usize,
bounds: &[(f64, f64)],
config: &SmartInitConfig,
peq_model: crate::cli::PeqModel,
) -> Vec<Vec<f64>> {
let mut main_rng: Box<dyn rand::RngCore> = if let Some(seed) = config.seed {
Box::new(StdRng::seed_from_u64(seed))
} else {
Box::new(rand::rng())
};
let smoothed = smooth_problem_response(freq_grid, target_response, config.smoothing_sigma);
let peaks = find_peaks(&smoothed, config.min_peak_height, config.min_peak_distance);
let inverted = -&smoothed;
let dips = find_peaks(&inverted, config.min_peak_height, config.min_peak_distance);
let mut problems = Vec::new();
for &peak_idx in &peaks {
if peak_idx < freq_grid.len() {
problems.push(FrequencyProblem {
frequency: freq_grid[peak_idx],
magnitude: -smoothed[peak_idx].abs(), q_factor: 1.0,
});
}
}
for &dip_idx in &dips {
if dip_idx < freq_grid.len() {
problems.push(FrequencyProblem {
frequency: freq_grid[dip_idx],
magnitude: smoothed[dip_idx].abs(), q_factor: 0.7, });
}
}
problems.sort_by(|a, b| {
b.magnitude
.abs()
.partial_cmp(&a.magnitude.abs())
.unwrap_or(std::cmp::Ordering::Equal)
});
let mut initial_guesses = Vec::new();
let params_per_filter = crate::param_utils::params_per_filter(peq_model);
for _guess_idx in 0..config.num_guesses {
let mut guess = Vec::with_capacity(num_filters * params_per_filter);
let mut used_problems = problems.clone();
while used_problems.len() < num_filters {
for &critical_freq in &config.critical_frequencies {
if critical_freq >= freq_grid[0] && critical_freq <= freq_grid[freq_grid.len() - 1]
{
used_problems.push(FrequencyProblem {
frequency: critical_freq,
magnitude: 0.5,
q_factor: 1.0,
});
}
if used_problems.len() >= num_filters {
break;
}
}
while used_problems.len() < num_filters {
let rand_freq = main_rng.random_range(freq_grid[0]..freq_grid[freq_grid.len() - 1]);
used_problems.push(FrequencyProblem {
frequency: rand_freq,
magnitude: main_rng.random_range(-2.0..2.0),
q_factor: 1.0,
});
}
}
for i in 0..num_filters {
let problem = &used_problems[i % used_problems.len()];
let freq_scale = if config.variation_factor > 0.0 {
1.0 + main_rng.random_range(-config.variation_factor..config.variation_factor)
} else {
1.0
};
let freq_var = problem.frequency * freq_scale;
let gain_var = problem.magnitude * (1.0 + main_rng.random_range(-0.2..0.2));
let q_var = problem.q_factor * (1.0 + main_rng.random_range(-0.3..0.3));
match peq_model {
crate::cli::PeqModel::Pk
| crate::cli::PeqModel::HpPk
| crate::cli::PeqModel::HpPkLp
| crate::cli::PeqModel::LsPk
| crate::cli::PeqModel::LsPkHs => {
let base_idx = i * 3;
let log_freq = freq_var
.log10()
.max(bounds[base_idx].0)
.min(bounds[base_idx].1);
let q_constrained = q_var
.max(bounds[base_idx + 1].0)
.min(bounds[base_idx + 1].1);
let gain_constrained = gain_var
.max(bounds[base_idx + 2].0)
.min(bounds[base_idx + 2].1);
guess.extend_from_slice(&[log_freq, q_constrained, gain_constrained]);
}
crate::cli::PeqModel::FreePkFree | crate::cli::PeqModel::Free => {
let base_idx = i * 4;
let filter_type = 0.0; let log_freq = freq_var
.log10()
.max(bounds[base_idx + 1].0)
.min(bounds[base_idx + 1].1);
let q_constrained = q_var
.max(bounds[base_idx + 2].0)
.min(bounds[base_idx + 2].1);
let gain_constrained = gain_var
.max(bounds[base_idx + 3].0)
.min(bounds[base_idx + 3].1);
guess.extend_from_slice(&[
filter_type,
log_freq,
q_constrained,
gain_constrained,
]);
}
}
}
initial_guesses.push(guess);
}
initial_guesses
}
/// Builds the integrality mask for `num_filters` filters whose parameters are
/// laid out as `[frequency, q, gain]`, three entries per filter.
///
/// Each filter's frequency entry is `use_freq_indexing`; its Q and gain
/// entries are always `false` (continuous). Presumably `true` tells the
/// optimizer that the frequency parameter is an integer grid index; confirm
/// against the caller.
pub fn generate_integrality_constraints(num_filters: usize, use_freq_indexing: bool) -> Vec<bool> {
    // One `[freq, q, gain]` triple per filter; `repeat` allocates exactly
    // `3 * num_filters` slots.
    [use_freq_indexing, false, false].repeat(num_filters)
}
#[cfg(test)]
mod tests {
    use super::*;
    use ndarray::Array1;

    #[test]
    fn test_generate_integrality_constraints() {
        // The frequency slot follows the flag; Q and gain are always continuous.
        let indexed = generate_integrality_constraints(2, true);
        assert_eq!(indexed.len(), 6);
        assert_eq!(indexed, [true, false, false, true, false, false]);

        let continuous = generate_integrality_constraints(2, false);
        assert_eq!(continuous.len(), 6);
        assert!(continuous.iter().all(|&flag| !flag));
    }

    #[test]
    fn test_create_smart_initial_guesses() {
        use crate::cli::PeqModel;
        let target_response = Array1::from(vec![0.0, 3.0, 0.0, -2.0, 0.0]);
        let freq_grid = Array1::from(vec![100.0, 200.0, 400.0, 800.0, 1600.0]);
        // A single Pk filter: [log10(freq), q, gain].
        let bounds = vec![
            (100.0_f64.log10(), 1600.0_f64.log10()),
            (0.5, 3.0),
            (-6.0, 6.0),
        ];
        let config = SmartInitConfig::default();
        let guesses = create_smart_initial_guesses(
            &target_response,
            &freq_grid,
            1,
            &bounds,
            &config,
            PeqModel::Pk,
        );
        assert_eq!(guesses.len(), config.num_guesses);
        for guess in &guesses {
            assert_eq!(guess.len(), 3);
            for (value, &(lo, hi)) in guess.iter().zip(&bounds) {
                assert!(*value >= lo && *value <= hi);
            }
        }
    }

    #[test]
    fn test_create_smart_initial_guesses_stable_across_grid_density() {
        use crate::cli::PeqModel;
        // One bump of height 6 at 1 kHz, sampled on a 5-point grid and on an
        // 81-point logarithmic grid.
        let coarse_freq_grid = Array1::from(vec![100.0, 300.0, 1000.0, 3000.0, 10000.0]);
        let coarse_response = Array1::from(vec![0.0, 0.0, 6.0, 0.0, 0.0]);
        let dense_freq_grid = Array1::logspace(10.0, 100.0_f64.log10(), 10_000.0_f64.log10(), 81);
        let dense_response = dense_freq_grid.mapv(|f| {
            let octaves_from_peak = (f / 1000.0).log2();
            6.0 * (-0.5 * (octaves_from_peak / 0.15).powi(2)).exp()
        });
        let bounds = vec![
            (100.0_f64.log10(), 10_000.0_f64.log10()),
            (0.5, 3.0),
            (-6.0, 6.0),
        ];
        let config = SmartInitConfig {
            num_guesses: 1,
            variation_factor: 0.0,
            seed: Some(42),
            ..SmartInitConfig::default()
        };
        // Centre frequency (Hz) of the only filter in the only guess.
        let first_filter_hz = |response: &Array1<f64>, grid: &Array1<f64>| -> f64 {
            let log_freq =
                create_smart_initial_guesses(response, grid, 1, &bounds, &config, PeqModel::Pk)[0][0];
            10.0_f64.powf(log_freq)
        };
        let coarse_freq = first_filter_hz(&coarse_response, &coarse_freq_grid);
        let dense_freq = first_filter_hz(&dense_response, &dense_freq_grid);

        assert!(coarse_freq > 700.0 && coarse_freq < 1400.0);
        assert!(dense_freq > 700.0 && dense_freq < 1400.0);
        assert!((dense_freq / coarse_freq).log2().abs() < 0.2);
    }
}