use super::{ObjectiveData, PenaltyMode, compute_fitness_penalties_ref};
use super::callback::{ProgressTracker, format_param_summary};
use ndarray::Array1;
#[allow(unused_imports)]
use metaheuristics_nature as mh;
#[allow(unused_imports)]
use mh::methods::{De as MhDe, Fa as MhFa, Pso as MhPso, Rga as MhRga, Tlbo as MhTlbo};
#[allow(unused_imports)]
use mh::{Bounded as MhBounded, Fitness as MhFitness, ObjFunc as MhObjFunc, Solver as MhSolver};
/// Progress snapshot handed to a metaheuristic progress callback.
///
/// Derives `Debug` and `Clone` so callbacks can log snapshots or keep a history of them.
#[derive(Debug, Clone)]
pub struct MHIntermediate {
    /// Best parameter vector recorded so far.
    pub x: Array1<f64>,
    /// Objective (fitness) value associated with `x`.
    pub fun: f64,
    /// Iteration counter at the time the snapshot was taken.
    pub iter: usize,
}
pub use crate::de::CallbackAction;
use std::sync::{Arc, Mutex};
#[derive(Clone)]
pub struct MHObjective {
pub data: ObjectiveData,
pub bounds: Vec<[f64; 2]>,
pub callback_state: Option<Arc<Mutex<CallbackState>>>,
}
/// Mutable tracking state shared between the objective and the progress-reporting hook.
///
/// Derives `Debug` and `Clone` so the state can be inspected or snapshotted.
#[derive(Debug, Clone)]
pub struct CallbackState {
    /// Lowest fitness value observed so far.
    pub best_fitness: f64,
    /// Parameters that produced `best_fitness`; empty until a first improvement is recorded.
    pub best_params: Vec<f64>,
    /// Total number of fitness evaluations recorded.
    pub eval_count: usize,
    /// Value of `eval_count` at the time of the last progress report.
    pub last_report_eval: usize,
}
impl MhBounded for MHObjective {
    /// Returns the search bounds, one `[lower, upper]` pair per parameter.
    fn bound(&self) -> &[[f64; 2]] {
        &self.bounds
    }
}
impl MhObjFunc for MHObjective {
    type Ys = f64;

    /// Evaluates `compute_fitness_penalties_ref` on `xs` and records the result.
    ///
    /// Recording only happens when a tracker is attached and its lock can be acquired. In that
    /// case it:
    /// - increments `eval_count`;
    /// - stores `xs` as the new best when the value is strictly lower than the current best.
    ///
    /// A poisoned lock silently skips tracking. A NaN value never replaces the best, because
    /// the `<` comparison is false.
    fn fitness(&self, xs: &[f64]) -> Self::Ys {
        let value = compute_fitness_penalties_ref(xs, &self.data);
        let guard = self
            .callback_state
            .as_ref()
            .and_then(|shared| shared.lock().ok());
        if let Some(mut tracked) = guard {
            tracked.eval_count += 1;
            if value < tracked.best_fitness {
                tracked.best_fitness = value;
                tracked.best_params = xs.to_vec();
            }
        }
        value
    }
}
/// Builds the default progress-printing callback for metaheuristic runs.
///
/// Each call feeds `intermediate.fun` into a `ProgressTracker`. A fitness line (prefixed with
/// `algo_name`) is printed via `crate::qa_println!` when any of these holds:
/// - the tracker reports it just started stalling;
/// - `tracker.stall_at_interval(25)` fires;
/// - `iter` is a multiple of 10.
///
/// Every 50 iterations (excluding iteration 0) a summary of the best parameters is printed
/// as well. The callback never requests early termination; it always returns
/// `CallbackAction::Continue`.
pub fn create_mh_callback(
    algo_name: &str,
) -> Box<dyn FnMut(&MHIntermediate) -> CallbackAction + Send> {
    let name = algo_name.to_string();
    let mut tracker = ProgressTracker::default();
    Box::new(move |intermediate: &MHIntermediate| -> CallbackAction {
        // The first element is printed as-is; its exact meaning is defined by ProgressTracker.
        let (improvement, _) = tracker.update(intermediate.fun);
        if tracker.just_started_stalling()
            || tracker.stall_at_interval(25)
            || intermediate.iter.is_multiple_of(10)
        {
            crate::qa_println!(
                "{} iter {:4} fitness={:.6e} {}",
                name,
                intermediate.iter,
                intermediate.fun,
                improvement
            );
        }
        if intermediate.iter > 0 && intermediate.iter.is_multiple_of(50) {
            // `as_slice` is `None` for non-contiguous arrays (e.g. after `invert_axis`);
            // fall back to an owned copy instead of panicking.
            let summary = match intermediate.x.as_slice() {
                Some(params) => format_param_summary(params, 3),
                None => {
                    let owned: Vec<f64> = intermediate.x.iter().copied().collect();
                    format_param_summary(&owned, 3)
                }
            };
            crate::qa_println!(" --> Best params: {}", summary);
        }
        CallbackAction::Continue
    })
}
/// Runs a metaheuristic optimization of the filter parameters with the default progress callback.
///
/// This is a convenience wrapper. It builds the callback with `create_mh_callback` under the
/// label `mh::<mh_name>`, forwards every other argument unchanged to
/// `optimize_filters_mh_with_callback`, and returns that function's result as-is.
pub fn optimize_filters_mh(
    x: &mut [f64],
    lower_bounds: &[f64],
    upper_bounds: &[f64],
    objective_data: ObjectiveData,
    mh_name: &str,
    population: usize,
    maxeval: usize,
) -> Result<(String, f64), (String, f64)> {
    let label = format!("mh::{}", mh_name);
    let progress = create_mh_callback(&label);
    optimize_filters_mh_with_callback(
        x,
        lower_bounds,
        upper_bounds,
        objective_data,
        mh_name,
        population,
        maxeval,
        progress,
    )
}
#[allow(clippy::too_many_arguments)]
pub fn optimize_filters_mh_with_callback(
x: &mut [f64],
lower_bounds: &[f64],
upper_bounds: &[f64],
objective_data: ObjectiveData,
mh_name: &str,
population: usize,
maxeval: usize,
mut callback: Box<dyn FnMut(&MHIntermediate) -> CallbackAction + Send>,
) -> Result<(String, f64), (String, f64)> {
let num_params = x.len();
assert_eq!(lower_bounds.len(), num_params);
assert_eq!(upper_bounds.len(), num_params);
let mut bounds: Vec<[f64; 2]> = Vec::with_capacity(num_params);
for i in 0..num_params {
bounds.push([lower_bounds[i], upper_bounds[i]]);
}
let mut penalty_data = objective_data.clone();
let penalty_mode = if mh_name == "pso" {
PenaltyMode::Pso
} else {
PenaltyMode::Standard
};
penalty_data.configure_penalties(penalty_mode);
let callback_state = Arc::new(Mutex::new(CallbackState {
best_fitness: f64::INFINITY,
best_params: vec![],
eval_count: 0,
last_report_eval: 0,
}));
let callback_state_task = Arc::clone(&callback_state);
let mh_obj = MHObjective {
data: penalty_data,
bounds,
callback_state: Some(Arc::clone(&callback_state)),
};
let builder = match mh_name {
"de" => MhSolver::build_boxed(MhDe::default(), mh_obj),
"pso" => {
let pso_tuned = MhPso::default()
.cognition(1.0) .social(1.5) .velocity(0.9); MhSolver::build_boxed(pso_tuned, mh_obj)
}
"rga" => {
MhSolver::build_boxed(MhRga::default(), mh_obj)
}
"tlbo" => MhSolver::build_boxed(MhTlbo, mh_obj),
"fa" | "firefly" => {
let fa_tuned = MhFa::default()
.alpha(0.5) .beta_min(1.0) .gamma(0.01); MhSolver::build_boxed(fa_tuned, mh_obj)
}
_ => MhSolver::build_boxed(MhDe::default(), mh_obj),
};
let pop = population.max(1);
let gens = (maxeval.max(pop)).div_ceil(pop);
let mut current_iter = 0_usize;
let report_interval = 100;
let solver = builder
.seed(0)
.pop_num(pop)
.task(move |_ctx| {
current_iter += 1;
if let Ok(mut state) = callback_state_task.lock() {
let evals_since_last = state.eval_count.saturating_sub(state.last_report_eval);
if evals_since_last >= report_interval {
let x_array = Array1::from(state.best_params.clone());
let intermediate = MHIntermediate {
x: x_array,
fun: state.best_fitness,
iter: current_iter,
};
let action = callback(&intermediate);
state.last_report_eval = state.eval_count;
if matches!(action, CallbackAction::Stop) {
return true; }
}
}
current_iter >= gens
})
.solve();
let best_xs = solver.as_best_xs();
if best_xs.len() == x.len() {
x.copy_from_slice(best_xs);
}
let best_val = *solver.as_best_fit();
Ok((format!("Metaheuristics({})", mh_name), best_val))
}