use super::optim::{AlgorithmType, get_all_algorithms};
use crate::LossType;
use crate::de::Strategy;
use clap::{Parser, ValueEnum};
use std::fmt;
use std::path::PathBuf;
use std::process;
// Filter-layout model selectable with `--peq-model`.
// The `#[value(name = ...)]` string on each variant is the spelling clap accepts
// on the command line. Plain `//` comments are used on purpose: clap's derive
// macros turn `///` doc comments on variants into user-visible help text.
#[derive(Debug, Clone, Copy, PartialEq, Eq, ValueEnum)]
pub enum PeqModel {
// Every filter is a peak/bell filter.
#[value(name = "pk")]
Pk,
// First filter is a highpass, the rest are peak filters.
#[value(name = "hp-pk")]
HpPk,
// First filter is a highpass, the last is a lowpass, the rest are peak filters.
#[value(name = "hp-pk-lp")]
HpPkLp,
// First filter is a low shelf, the rest are peak filters.
#[value(name = "ls-pk")]
LsPk,
// First filter is a low shelf, the last is a high shelf, the rest are peak filters.
#[value(name = "ls-pk-hs")]
LsPkHs,
// First and last filters may be of any type; the middle ones are peak filters.
#[value(name = "free-pk-free")]
FreePkFree,
// Every filter may be of any type (peak, highpass, lowpass, shelf).
#[value(name = "free")]
Free,
}
impl fmt::Display for PeqModel {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
PeqModel::Pk => write!(f, "pk"),
PeqModel::HpPk => write!(f, "hp-pk"),
PeqModel::LsPk => write!(f, "ls-pk"),
PeqModel::HpPkLp => write!(f, "hp-pk-lp"),
PeqModel::LsPkHs => write!(f, "ls-pk-hs"),
PeqModel::FreePkFree => write!(f, "free-pk-free"),
PeqModel::Free => write!(f, "free"),
}
}
}
impl PeqModel {
pub fn all() -> Vec<Self> {
vec![
PeqModel::Pk,
PeqModel::HpPk,
PeqModel::LsPk,
PeqModel::HpPkLp,
PeqModel::LsPkHs,
PeqModel::FreePkFree,
PeqModel::Free,
]
}
pub fn description(&self) -> &'static str {
match self {
PeqModel::Pk => "All filters are peak/bell filters",
PeqModel::HpPk => "First filter is highpass, rest are peak filters",
PeqModel::LsPk => "First filter is low shelve, rest are peak filters",
PeqModel::HpPkLp => "First filter is highpass, last is lowpass, rest are peak filters",
PeqModel::LsPkHs => {
"First filter is low shelve, last is high shelve, rest are peak filters"
}
PeqModel::FreePkFree => {
"First and last filters can be any type, middle filters are peak"
}
PeqModel::Free => "All filters can be any type (peak, highpass, lowpass, shelf)",
}
}
}
// Command-line arguments, parsed through clap's derive `Parser`.
//
// Plain `//` comments are used on purpose: clap's derive macros turn `///`
// doc comments into help/about text, which would change the program's output.
// Range constraints mentioned below are the ones enforced by `validate_args`.
#[derive(Parser, Debug, Clone)]
#[command(author, about, long_about = None)]
pub struct Args {
// Number of filters (`-n`); must be in 1..=50.
#[arg(short = 'n', long, default_value_t = 7)]
pub num_filters: usize,
// Optional path given with `--curve`; the usage examples in this file pass a CSV file.
#[arg(short, long)]
pub curve: Option<PathBuf>,
// Optional path given with `--target`; its use is not visible in this file.
#[arg(short, long)]
pub target: Option<PathBuf>,
// Sample rate in Hz; half of it is the Nyquist limit that `max_freq` may not exceed.
#[arg(short, long, default_value_t = 48000.0)]
pub sample_rate: f64,
// Upper dB bound; the CLI parser rejects negative values and `min_db` must be <= `max_db`.
#[arg(long, default_value_t = 3.0, value_parser = parse_nonnegative_f64)]
pub max_db: f64,
// Lower dB bound; the CLI parser only accepts values > 0.
// NOTE(review): `speaker_defaults` and the presets assign -12.0 directly, which
// this parser would reject on the command line — confirm the intended sign convention.
#[arg(long, default_value_t = 1.0, value_parser = parse_strictly_positive_f64)]
pub min_db: f64,
// Upper Q bound; `min_q` must be <= `max_q`.
#[arg(long, default_value_t = 3.0)]
pub max_q: f64,
// Lower Q bound.
#[arg(long, default_value_t = 1.0)]
pub min_q: f64,
// Lowest allowed frequency in Hz; must be >= 20 and <= `max_freq`.
#[arg(long, default_value_t = 60.0)]
pub min_freq: f64,
// Highest allowed frequency in Hz; must be <= 20000 and <= sample_rate / 2.
#[arg(long, default_value_t = 16000.0)]
pub max_freq: f64,
// Optional output path; what is written there is not visible in this file.
#[arg(short, long)]
pub output: Option<PathBuf>,
// Optional speaker identifier; presumably selects remote measurement data together
// with `version` and `measurement` — TODO confirm against the caller.
#[arg(long)]
pub speaker: Option<String>,
// Optional free-form string passed with `--version` (not clap's built-in version flag,
// which is not enabled in `#[command(...)]`).
#[arg(long)]
pub version: Option<String>,
// Optional measurement name; use not visible in this file.
#[arg(long)]
pub measurement: Option<String>,
// Name of the curve to work on.
#[arg(long, default_value = "Listening Window")]
pub curve_name: String,
// Optimizer name; must be known to `crate::optim::find_algorithm_info`.
#[arg(long, default_value = "nlopt:cobyla")]
pub algo: String,
// Population size; must be > 0.
#[arg(long, default_value_t = 300)]
pub population: usize,
// Maximum number of evaluations; must be > 0.
#[arg(long, default_value_t = 2_000)]
pub maxeval: usize,
// When set, `local_algo` is validated as well (presumably a local refinement pass follows).
#[arg(long, default_value_t = false)]
pub refine: bool,
// Algorithm name checked with `find_algorithm_info` only when `refine` is set.
#[arg(long, default_value = "cobyla")]
pub local_algo: String,
// Presumably the minimum spacing between filters in octaves — TODO confirm.
#[arg(long, default_value_t = 0.2)]
pub min_spacing_oct: f64,
// Presumably the weight of the spacing penalty — TODO confirm.
#[arg(long, default_value_t = 20.0)]
pub spacing_weight: f64,
// NOTE(review): a clap bool flag that already defaults to `true` can presumably
// not be switched off from the command line — confirm whether that is intended.
#[arg(long, default_value_t = true)]
pub smooth: bool,
// Smoothing parameter; must be in 1..=24.
#[arg(long, default_value_t = 2)]
pub smooth_n: usize,
// Loss selection; `LossType::DriversFlat` additionally requires the driver arguments below.
#[arg(long, value_enum, default_value_t = LossType::SpeakerFlat)]
pub loss: LossType,
// Filter-layout model, see `PeqModel`.
#[arg(long, value_enum, default_value_t = PeqModel::Pk)]
pub peq_model: PeqModel,
// Listing flag; presumably makes the caller invoke `display_peq_model_list` — confirm.
#[arg(long, default_value_t = false)]
pub peq_model_list: bool,
// Listing flag; presumably makes the caller invoke `display_algorithm_list` — confirm.
#[arg(long, default_value_t = false)]
pub algo_list: bool,
// Tolerance; must be > 0.
#[arg(long, default_value_t = 1e-3)]
pub tolerance: f64,
// Absolute tolerance; must be >= 0.
#[arg(long, default_value_t = 1e-4)]
pub atolerance: f64,
// Recombination probability; the CLI parser restricts it to [0.0, 1.0].
#[arg(long, default_value_t = 0.9, value_parser = parse_recombination_probability)]
pub recombination: f64,
// Strategy name; `validate_args` parses it with `Strategy::from_str` for DE algorithms.
#[arg(long, default_value = "currenttobest1bin")]
pub strategy: String,
// Listing flag; presumably makes the caller invoke `display_strategy_list` — confirm.
#[arg(long, default_value_t = false)]
pub strategy_list: bool,
// Adaptive weight for F; must be within [0.0, 1.0].
#[arg(long, default_value_t = 0.9)]
pub adaptive_weight_f: f64,
// Adaptive weight for CR; must be within [0.0, 1.0].
#[arg(long, default_value_t = 0.9)]
pub adaptive_weight_cr: f64,
// Set by `--no-parallel`.
#[arg(long = "no-parallel", default_value_t = false)]
pub no_parallel: bool,
// Thread count; the meaning of the default 0 is not visible in this file.
#[arg(long, default_value_t = 0)]
pub parallel_threads: usize,
// Optional seed; presumably for the random number generator — TODO confirm.
#[arg(long)]
pub seed: Option<u64>,
// Optional value shown as THRESHOLD in the help output.
#[arg(long, value_name = "THRESHOLD")]
pub qa: Option<f64>,
// Driver paths: only accepted together with `--loss drivers-flat`, which
// requires at least `driver1` and `driver2`.
#[arg(long)]
pub driver1: Option<PathBuf>,
#[arg(long)]
pub driver2: Option<PathBuf>,
#[arg(long)]
pub driver3: Option<PathBuf>,
#[arg(long)]
pub driver4: Option<PathBuf>,
// Crossover type; checked against butterworth2 / linkwitzriley2 / linkwitzriley4
// when the loss is `LossType::DriversFlat`.
#[arg(long, default_value = "linkwitzriley4")]
pub crossover_type: String,
// Preset name consumed by `Args::apply_preset` (quick, balanced, max-quality, score).
#[arg(long)]
pub preset: Option<String>,
}
impl Args {
    /// Returns the PEQ model these arguments select.
    pub fn effective_peq_model(&self) -> PeqModel {
        self.peq_model
    }

    /// Reports whether the selected model puts a highpass filter first.
    pub fn uses_highpass_first(&self) -> bool {
        match self.effective_peq_model() {
            PeqModel::HpPk | PeqModel::HpPkLp => true,
            PeqModel::Pk
            | PeqModel::LsPk
            | PeqModel::LsPkHs
            | PeqModel::FreePkFree
            | PeqModel::Free => false,
        }
    }

    /// Builds the argument set used as the baseline for speaker optimization.
    ///
    /// Fields are listed in struct declaration order.
    pub fn speaker_defaults() -> Self {
        Self {
            num_filters: 5,
            curve: None,
            target: None,
            sample_rate: 48000.0,
            max_db: 12.0,
            min_db: -12.0,
            max_q: 10.0,
            min_q: 0.5,
            min_freq: 20.0,
            max_freq: 20000.0,
            output: None,
            speaker: None,
            version: None,
            measurement: None,
            curve_name: String::from("Listening Window"),
            algo: String::from("autoeq:de"),
            population: 50,
            maxeval: 2000,
            refine: false,
            local_algo: String::from("cobyla"),
            min_spacing_oct: 0.5,
            spacing_weight: 20.0,
            smooth: true,
            smooth_n: 1,
            loss: LossType::SpeakerFlat,
            peq_model: PeqModel::Pk,
            peq_model_list: false,
            algo_list: false,
            tolerance: 1e-3,
            atolerance: 1e-4,
            recombination: 0.9,
            strategy: String::from("currenttobest1bin"),
            strategy_list: false,
            adaptive_weight_f: 0.8,
            adaptive_weight_cr: 0.7,
            no_parallel: false,
            parallel_threads: 0,
            seed: None,
            qa: None,
            driver1: None,
            driver2: None,
            driver3: None,
            driver4: None,
            crossover_type: String::from("linkwitzriley4"),
            preset: None,
        }
    }

    /// Speaker defaults with the headphone-score loss and seven filters.
    pub fn headphone_defaults() -> Self {
        let mut args = Self::speaker_defaults();
        args.loss = LossType::HeadphoneScore;
        args.num_filters = 7;
        args
    }

    /// Speaker defaults with ten filters and the upper frequency capped at 500 Hz.
    pub fn roomeq_defaults() -> Self {
        let mut args = Self::speaker_defaults();
        args.num_filters = 10;
        args.max_freq = 500.0;
        args
    }

    /// Overwrites the fields governed by the preset named in `self.preset`.
    ///
    /// Does nothing when no preset is set. An unknown name only prints a
    /// warning to stderr and leaves every field untouched.
    pub fn apply_preset(&mut self) {
        let requested = match self.preset.as_deref() {
            Some(name) => name,
            None => return,
        };

        // Preset-specific settings first.
        match requested {
            "quick" => {
                self.num_filters = 5;
                self.population = 40;
                self.maxeval = 2000;
                self.refine = false;
            }
            "balanced" => {
                self.num_filters = 7;
                self.population = 80;
                self.maxeval = 5000;
                self.refine = true;
            }
            "max-quality" => {
                self.num_filters = 10;
                self.peq_model = PeqModel::LsPkHs;
                self.population = 200;
                self.maxeval = 20000;
                self.refine = true;
            }
            "score" => {
                self.num_filters = 7;
                self.loss = LossType::SpeakerScore;
                self.population = 100;
                self.maxeval = 10000;
                self.refine = true;
            }
            other => {
                eprintln!(
                    "Warning: unknown preset '{}'. Available: quick, balanced, max-quality, score",
                    other
                );
                return;
            }
        }

        // Every known preset shares the same Q and gain bounds.
        self.min_q = 0.5;
        self.max_q = 6.0;
        self.min_db = -12.0;
        self.max_db = 6.0;
    }

    /// Prints the preset overview to stdout and exits with status 0.
    pub fn display_preset_list() -> ! {
        let lines = [
            "Available Presets",
            "=================\n",
            " quick 5 filters, fast optimization (seconds)",
            " balanced 7 filters, good balance of quality and speed (default)",
            " max-quality 10 filters with shelves, best results (slow)",
            " score 7 filters, listener preference score optimization",
            "\nUse: --preset <name>",
            "Individual flags (e.g. -n, --max-q) override preset values.",
        ];
        for line in lines {
            println!("{}", line);
        }
        process::exit(0);
    }
}
pub fn display_algorithm_list() -> ! {
println!("Available Optimization Algorithms");
println!("=================================\n");
let algorithms = get_all_algorithms();
let mut nlopt_algos = Vec::new();
let mut metaheuristics_algos = Vec::new();
let mut autoeq_algos = Vec::new();
for algo in &algorithms {
match algo.library {
"NLOPT" => nlopt_algos.push(algo),
"Metaheuristics" => metaheuristics_algos.push(algo),
"AutoEQ" => autoeq_algos.push(algo),
_ => {} }
}
if !nlopt_algos.is_empty() {
println!("📊 NLOPT Library Algorithms:");
let mut global = Vec::new();
let mut local = Vec::new();
for algo in nlopt_algos {
match algo.algorithm_type {
AlgorithmType::Global => global.push(algo),
AlgorithmType::Local => local.push(algo),
}
}
if !global.is_empty() {
println!(" 🌍 Global Optimizers (best for exploring solution space):");
for algo in global {
print!(" - {:<20}", algo.name);
print!(" | Constraints: ");
if algo.supports_nonlinear_constraints {
print!("✅ Nonlinear");
} else if algo.supports_linear_constraints {
print!("🔶 Linear only");
} else {
print!("❌ None");
}
let description = match algo.name {
"nlopt:isres" => {
" | Improved Stochastic Ranking Evolution Strategy (recommended)"
}
"nlopt:ags" => " | Adaptive Geometric Search",
"nlopt:origdirect" => " | DIRECT global optimization (original version)",
"nlopt:crs2lm" => " | Controlled Random Search with local mutation",
"nlopt:direct" => " | DIRECT global optimization",
"nlopt:directl" => " | DIRECT-L (locally biased version)",
"nlopt:gmlsl" => " | Global Multi-Level Single-Linkage",
"nlopt:gmlsllds" => " | GMLSL with low-discrepancy sequence",
"nlopt:stogo" => " | Stochastic Global Optimization",
"nlopt:stogorand" => " | StoGO with randomized search",
_ => "",
};
println!("{}", description);
}
println!();
}
if !local.is_empty() {
println!(" 🎯 Local Optimizers (fast refinement from good starting points):");
for algo in local {
print!(" - {:<20}", algo.name);
print!(" | Constraints: ");
if algo.supports_nonlinear_constraints {
print!("✅ Nonlinear");
} else if algo.supports_linear_constraints {
print!("🔶 Linear only");
} else {
print!("❌ None");
}
let description = match algo.name {
"nlopt:cobyla" => {
" | Constrained Optimization BY Linear Approximations (recommended for local)"
}
"nlopt:bobyqa" => " | Bound Optimization BY Quadratic Approximation",
"nlopt:neldermead" => " | Nelder-Mead simplex algorithm",
"nlopt:sbplx" => " | Subplex (variant of Nelder-Mead)",
"nlopt:slsqp" => " | Sequential Least SQuares Programming",
_ => "",
};
println!("{}", description);
}
println!();
}
}
if !metaheuristics_algos.is_empty() {
println!("🧬 Metaheuristics Library Algorithms:");
println!(" Nature-inspired global optimization (penalty-based constraints)\n");
for algo in metaheuristics_algos {
print!(" - {:<20}", algo.name);
let description = match algo.name {
"mh:de" => " | Differential Evolution (robust, good convergence)",
"mh:pso" => " | Particle Swarm Optimization (fast exploration)",
"mh:rga" => " | Real-coded Genetic Algorithm (diverse search)",
"mh:tlbo" => " | Teaching-Learning-Based Optimization (parameter-free)",
"mh:firefly" => " | Firefly Algorithm (multi-modal problems)",
_ => "",
};
println!("{}", description);
}
println!();
}
if !autoeq_algos.is_empty() {
println!("🎵 AutoEQ Custom Algorithms:");
println!(" Specialized algorithms developed for audio filter optimization\n");
for algo in autoeq_algos {
print!(" - {:<20}", algo.name);
print!(" | Constraints: ");
if algo.supports_nonlinear_constraints {
print!("✅ Nonlinear");
} else {
print!("❌ Penalty-based");
}
let description = match algo.name {
"autoeq:de" => " | Adaptive DE with constraint handling (experimental)",
_ => "",
};
println!("{}", description);
}
println!();
}
println!("Usage Examples:");
println!("==============\n");
println!(" # Use ISRES (recommended global optimizer):");
println!(" autoeq --algo nlopt:isres --curve input.csv\n");
println!(" # Use COBYLA (fast local optimizer):");
println!(" autoeq --algo nlopt:cobyla --curve input.csv\n");
println!(" # Use Differential Evolution from metaheuristics:");
println!(" autoeq --algo mh:de --curve input.csv\n");
println!(" # Backward compatibility (maps to nlopt:cobyla):");
println!(" autoeq --algo cobyla --curve input.csv\n");
println!("Recommendations:");
println!("===============\n");
println!(" 🎯 For best results: nlopt:isres (global) + --refine with nlopt:cobyla (local)");
println!(" ⚡ For speed: nlopt:cobyla (if you have a good initial guess)");
println!(" 🧪 For experimentation: mh:de or mh:pso from metaheuristics library");
println!(
" ⚖️ For constrained problems: Prefer algorithms with ✅ Nonlinear constraint support"
);
process::exit(0);
}
/// Prints the differential-evolution strategy overview (classic strategies,
/// adaptive strategies, naming conventions, examples, recommendations) to
/// stdout and exits with status 0.
pub fn display_strategy_list() -> ! {
    println!("Available Differential Evolution (DE) Strategies");
    println!("===============================================\n");

    // Rows are (command-line name, enum variant name, summary, advice).
    // The variant name is kept for reference only and is not printed.
    let table = [
        (
            "best1bin",
            "Best1Bin",
            "Use best individual + 1 random difference (binomial crossover)",
            "Global exploration with fast convergence",
        ),
        (
            "best1exp",
            "Best1Exp",
            "Use best individual + 1 random difference (exponential crossover)",
            "Similar to best1bin with different crossover",
        ),
        (
            "rand1bin",
            "Rand1Bin",
            "Use random individual + 1 random difference (binomial crossover)",
            "Good diversity, slower convergence",
        ),
        (
            "rand1exp",
            "Rand1Exp",
            "Use random individual + 1 random difference (exponential crossover)",
            "Similar to rand1bin with different crossover",
        ),
        (
            "rand2bin",
            "Rand2Bin",
            "Use random individual + 2 random differences (binomial crossover)",
            "High exploration, may be slower",
        ),
        (
            "rand2exp",
            "Rand2Exp",
            "Use random individual + 2 random differences (exponential crossover)",
            "Similar to rand2bin with different crossover",
        ),
        (
            "currenttobest1bin",
            "CurrentToBest1Bin",
            "Blend current with best + random difference (binomial)",
            "Balanced exploration/exploitation (recommended)",
        ),
        (
            "currenttobest1exp",
            "CurrentToBest1Exp",
            "Blend current with best + random difference (exponential)",
            "Similar to currenttobest1bin",
        ),
        (
            "best2bin",
            "Best2Bin",
            "Use best individual + 2 random differences (binomial crossover)",
            "Fast convergence, may get trapped locally",
        ),
        (
            "best2exp",
            "Best2Exp",
            "Use best individual + 2 random differences (exponential crossover)",
            "Similar to best2bin",
        ),
        (
            "randtobest1bin",
            "RandToBest1Bin",
            "Blend random with best + random difference (binomial)",
            "Good balance of diversity and convergence",
        ),
        (
            "randtobest1exp",
            "RandToBest1Exp",
            "Blend random with best + random difference (exponential)",
            "Similar to randtobest1bin",
        ),
        (
            "adaptivebin",
            "AdaptiveBin",
            "Self-adaptive mutation with top-w% selection (binomial)",
            "Advanced adaptive strategy (experimental)",
        ),
        (
            "adaptiveexp",
            "AdaptiveExp",
            "Self-adaptive mutation with top-w% selection (exponential)",
            "Advanced adaptive strategy (experimental)",
        ),
    ];
    let is_adaptive = |name: &str| name.starts_with("adaptive");

    println!("🎯 Classic DE Strategies (well-tested, reliable):");
    // Only the first twelve rows are considered, and adaptive rows are skipped.
    table
        .iter()
        .take(12)
        .filter(|row| !is_adaptive(row.0))
        .for_each(|&(name, _variant, summary, advice)| {
            println!(" - {:<20} | {}", name, summary);
            println!(" {:<20} | 💡 {}", "", advice);
            if name == "currenttobest1bin" {
                println!(" {:<20} | ⭐ Recommended default strategy", "");
            }
            println!();
        });

    println!("🧬 Adaptive DE Strategies (experimental, research-based):");
    table
        .iter()
        .filter(|row| is_adaptive(row.0))
        .for_each(|&(name, _variant, summary, advice)| {
            println!(" - {:<20} | {}", name, summary);
            println!(" {:<20} | 💡 {}", "", advice);
            println!(
                " {:<20} | 🔧 Requires --adaptive-weight-f and --adaptive-weight-cr",
                ""
            );
            println!();
        });

    // Static trailer: naming conventions, usage examples, recommendations.
    for line in [
        "Strategy Naming Conventions:",
        "==========================\n",
        " • 'bin' = Binomial (uniform) crossover - each gene has equal probability",
        " • 'exp' = Exponential crossover - contiguous segments are more likely",
        " • Numbers (1, 2) indicate how many difference vectors are used\n",
        "Usage Examples:",
        "==============\n",
        " # Use recommended default strategy:",
        " autoeq --algo autoeq:de --strategy currenttobest1bin --curve input.csv\n",
        " # Use adaptive strategy with custom weights:",
        " autoeq --algo autoeq:de --strategy adaptivebin --adaptive-weight-f 0.8 --adaptive-weight-cr 0.7\n",
        " # Use classic exploration strategy:",
        " autoeq --algo autoeq:de --strategy rand1bin --curve input.csv\n",
        "Recommendations:",
        "===============\n",
        " ⭐ For general use: currenttobest1bin (good balance of exploration and exploitation)",
        " 🚀 For fast convergence: best1bin or best2bin (may get trapped in local optima)",
        " 🌍 For thorough exploration: rand1bin or rand2bin (slower but more robust)",
        " 🧪 For research/experimentation: adaptivebin or adaptiveexp (requires parameter tuning)",
    ] {
        println!("{}", line);
    }
    process::exit(0);
}
/// Checks parsed arguments for values that cannot be used for an optimization run.
///
/// Checks run in a fixed order and the first failing one is reported.
///
/// # Errors
///
/// Returns a human-readable message when:
/// - a differential-evolution algorithm is selected and `strategy` does not parse,
/// - `algo` (or `local_algo` when `refine` is set) is unknown,
/// - a min/max pair is inverted, or a frequency lies outside 20..=20000 Hz or
///   above the Nyquist frequency,
/// - a count, tolerance or adaptive weight is outside its allowed range,
/// - the driver arguments do not match the selected loss.
pub fn validate_args(args: &Args) -> Result<(), String> {
    // Only differential-evolution algorithms consume `--strategy`. Compare the
    // part after the optional `library:` prefix rather than searching for the
    // substring "de", which also matched unrelated names such as
    // `nlopt:neldermead` and rejected them for an irrelevant strategy value.
    let uses_de = args.algo.rsplit(':').next() == Some("de");
    if uses_de {
        use std::str::FromStr;
        if let Err(err) = Strategy::from_str(&args.strategy) {
            return Err(format!(
                "Invalid DE strategy '{}': {}. Use --strategy-list to see available strategies.",
                args.strategy, err
            ));
        }
    }

    if crate::optim::find_algorithm_info(&args.algo).is_none() {
        return Err(format!(
            "Unknown algorithm: '{}'. Use --algo-list to see available algorithms.",
            args.algo
        ));
    }
    // The local algorithm only matters when a refinement pass was requested.
    if args.refine && crate::optim::find_algorithm_info(&args.local_algo).is_none() {
        return Err(format!(
            "Unknown local algorithm: '{}'. Use --algo-list to see available algorithms.",
            args.local_algo
        ));
    }

    if args.min_q > args.max_q {
        return Err(format!(
            "Invalid Q factor range: min_q ({}) must be <= max_q ({})",
            args.min_q, args.max_q
        ));
    }
    if args.min_freq > args.max_freq {
        return Err(format!(
            "Invalid frequency range: min_freq ({}) must be <= max_freq ({})",
            args.min_freq, args.max_freq
        ));
    }
    if args.min_db > args.max_db {
        return Err(format!(
            "Invalid dB range: min_db ({}) must be <= max_db ({})",
            args.min_db, args.max_db
        ));
    }
    if args.min_freq < 20.0 {
        return Err(format!(
            "Invalid min_freq: {} Hz. Must be >= 20 Hz (reasonable audio range)",
            args.min_freq
        ));
    }
    if args.max_freq > 20000.0 {
        return Err(format!(
            "Invalid max_freq: {} Hz. Must be <= 20,000 Hz (reasonable audio range)",
            args.max_freq
        ));
    }
    let nyquist = args.sample_rate / 2.0;
    if args.max_freq > nyquist {
        return Err(format!(
            "max_freq ({:.0} Hz) exceeds Nyquist frequency ({:.0} Hz) at sample rate {:.0} Hz. \
             Biquad filters above Nyquist have undefined behavior.",
            args.max_freq, nyquist, args.sample_rate
        ));
    }

    if args.smooth_n < 1 || args.smooth_n > 24 {
        return Err(format!(
            "Invalid smooth_n: {}. Must be in range [1..24]",
            args.smooth_n
        ));
    }
    if args.population == 0 {
        return Err("Population size must be > 0".to_string());
    }
    if args.maxeval == 0 {
        return Err("Maximum evaluations must be > 0".to_string());
    }
    if args.num_filters == 0 {
        return Err("Number of filters must be > 0".to_string());
    }
    if args.num_filters > 50 {
        return Err(format!(
            "Number of filters ({}) is very high. Consider using <= 50 filters for reasonable performance",
            args.num_filters
        ));
    }
    if args.tolerance <= 0.0 {
        return Err("Tolerance must be > 0".to_string());
    }
    if args.atolerance < 0.0 {
        return Err("Absolute tolerance must be >= 0".to_string());
    }
    if args.adaptive_weight_f < 0.0 || args.adaptive_weight_f > 1.0 {
        return Err("Adaptive weight for F must be between 0.0 and 1.0".to_string());
    }
    if args.adaptive_weight_cr < 0.0 || args.adaptive_weight_cr > 1.0 {
        return Err("Adaptive weight for CR must be between 0.0 and 1.0".to_string());
    }

    if args.loss == LossType::DriversFlat {
        if args.driver1.is_none() || args.driver2.is_none() {
            return Err("Multi-driver optimization requires at least --driver1 and --driver2 when using --loss drivers-flat".to_string());
        }
        let valid_crossover_types = ["butterworth2", "linkwitzriley2", "linkwitzriley4"];
        if !valid_crossover_types.contains(&args.crossover_type.as_str()) {
            return Err(format!(
                "Invalid crossover type '{}'. Valid types: {}",
                args.crossover_type,
                valid_crossover_types.join(", ")
            ));
        }
        // Defensive: with driver1 and driver2 required above and only four
        // driver slots, the count is already within 2..=4 at this point.
        let n_drivers = [&args.driver1, &args.driver2, &args.driver3, &args.driver4]
            .iter()
            .filter(|d| d.is_some())
            .count();
        if !(2..=4).contains(&n_drivers) {
            return Err(format!(
                "Multi-driver optimization requires 2-4 drivers, got {}",
                n_drivers
            ));
        }
    } else if args.driver1.is_some()
        || args.driver2.is_some()
        || args.driver3.is_some()
        || args.driver4.is_some()
    {
        return Err("Driver arguments (--driver1, --driver2, etc.) can only be used with --loss drivers-flat".to_string());
    }
    Ok(())
}
/// Runs `validate_args`; on failure prints the message to stderr and exits
/// the process with status 1. Returns normally when validation succeeds.
pub fn validate_args_or_exit(args: &Args) {
    match validate_args(args) {
        Ok(()) => {}
        Err(error) => {
            eprintln!("❌ Validation Error: {}", error);
            process::exit(1);
        }
    }
}
/// Prints every PEQ model with its description plus a few examples to stdout,
/// then exits with status 0.
pub fn display_peq_model_list() -> ! {
    // An empty entry produces a blank line, exactly like a bare `println!()`.
    let header = [
        "Available PEQ Models",
        "===================",
        "",
        "The PEQ model defines the structure and constraints of the equalizer filters.",
        "",
    ];
    for line in header {
        println!("{}", line);
    }

    PeqModel::all().into_iter().for_each(|model| {
        println!(" --peq-model {}", model);
        println!(" {}", model.description());
        println!();
    });

    let examples = [
        "Examples:",
        " autoeq --peq-model pk # All peak filters (default)",
        " autoeq --peq-model hp-pk # Highpass + peaks",
        " autoeq --peq-model hp-pk-lp # Highpass + peaks + lowpass",
    ];
    for line in examples {
        println!("{}", line);
    }
    process::exit(0);
}
/// clap value parser: accepts `s` only if it is a finite `f64` greater than zero.
///
/// # Errors
///
/// Returns a message when `s` is not a float, is infinite, or is not `> 0`
/// (which includes NaN, since every comparison with NaN is false).
fn parse_strictly_positive_f64(s: &str) -> Result<f64, String> {
    let v: f64 = s.parse().map_err(|_| format!("invalid float: {s}"))?;
    // `f64::from_str` accepts "inf"/"infinity"; an infinite bound is not usable.
    if v.is_infinite() {
        return Err("value must be finite".to_string());
    }
    if v > 0.0 {
        Ok(v)
    } else {
        Err("value must be strictly positive (> 0)".to_string())
    }
}
/// clap value parser: accepts `s` only if it is a finite `f64` that is `>= 0`.
///
/// # Errors
///
/// Returns a message when `s` is not a float, is infinite, or is not `>= 0`
/// (which includes NaN, since every comparison with NaN is false).
fn parse_nonnegative_f64(s: &str) -> Result<f64, String> {
    let v: f64 = s.parse().map_err(|_| format!("invalid float: {s}"))?;
    // `f64::from_str` accepts "inf"/"infinity"; an infinite bound is not usable.
    if v.is_infinite() {
        return Err("value must be finite".to_string());
    }
    if v >= 0.0 {
        Ok(v)
    } else {
        Err("value must be non-negative (>= 0)".to_string())
    }
}
/// clap value parser: accepts `s` only if it is an `f64` within `[0.0, 1.0]`.
///
/// # Errors
///
/// Returns a message when `s` is not a float or lies outside the closed unit
/// interval (NaN is outside it).
fn parse_recombination_probability(s: &str) -> Result<f64, String> {
    match s.parse::<f64>() {
        Err(_) => Err(format!("invalid float: {s}")),
        Ok(probability) if (0.0..=1.0).contains(&probability) => Ok(probability),
        Ok(_) => Err("recombination probability must be between 0.0 and 1.0".to_string()),
    }
}