use std::collections::HashSet;
use rand::rngs::StdRng;
use crate::clustering::{self, ClusteringConfig};
use crate::constants::{
CONVERGENCE_THRESHOLD, CONVERGENCE_WINDOW, ENABLE_EARLY_STOPPING, NUM_OUTPUT_MODELS,
};
use crate::hall_of_fame::{HallOfFame, HallOfFameEntry};
use crate::population::Population;
use crate::structure::{combine_molecules, Molecule};
/// Outcome of a [`run_ga`] invocation.
pub struct GaResult {
/// Hall of fame fed with the chromosomes of every population that was
/// evaluated during the run.
pub hall_of_fame: HallOfFame,
/// Value of the generation counter when the loop ended: equal to
/// `max_generations` on a full run, or the index of the generation at
/// which early stopping triggered.
pub generations_run: u64,
/// `true` when the loop was left through the early-stopping check rather
/// than by exhausting `max_generations`.
pub converged_early: bool,
/// The population as it stood when the run finished.
pub final_population: Population,
}
/// Drives the genetic algorithm for at most `max_generations` generations.
///
/// Each generation the population is evaluated, its chromosomes are offered
/// to the hall of fame, and `on_generation` is invoked with the generation
/// index and the evaluated population. From the second generation onward the
/// relative improvement of the minimum fitness over the previous generation
/// is compared against `CONVERGENCE_THRESHOLD`; when `ENABLE_EARLY_STOPPING`
/// is set and the improvement has stayed below the threshold for
/// `CONVERGENCE_WINDOW` consecutive generations, the run stops early.
///
/// When the run is not stopped early, the population produced by the last
/// `evolve` call (or the untouched input population if `max_generations` is
/// zero) is evaluated once more and offered to the hall of fame; the callback
/// is not invoked for that final evaluation.
///
/// # Parameters
/// - `pop`: starting population; consumed and returned in the result.
/// - `rng`: random number generator handed to `Population::evolve`.
/// - `max_generations`: upper bound on the number of evolve steps.
/// - `on_generation`: observer called once per evaluated generation.
///
/// # Returns
/// A [`GaResult`] holding the hall of fame, the generation counter at exit,
/// the early-stop flag and the final population.
pub fn run_ga<F>(
    mut pop: Population,
    rng: &mut StdRng,
    max_generations: u64,
    mut on_generation: F,
) -> GaResult
where
    F: FnMut(u64, &Population),
{
    // Relative decrease from `previous` to `current`; reported as zero when
    // `previous` is too close to zero to be used as a denominator.
    fn relative_improvement(previous: f64, current: f64) -> f64 {
        if previous.abs() < f64::EPSILON {
            0.0
        } else {
            (previous - current) / previous.abs()
        }
    }

    let mut hall_of_fame = HallOfFame::new();
    let mut stalled_generations = 0u64;
    let mut previous_best = f64::MAX;
    let mut converged_early = false;
    // Stays at `max_generations` unless the loop is left early.
    let mut generations_run = max_generations;

    for generation in 0..max_generations {
        pop.eval_fitness();
        hall_of_fame.add_from_population(&pop.chromosomes);
        let current_best = pop.get_min_fittest().fitness;

        // The first generation has nothing to compare against.
        if generation > 0 {
            if relative_improvement(previous_best, current_best) < CONVERGENCE_THRESHOLD {
                stalled_generations += 1;
            } else {
                stalled_generations = 0;
            }
        }
        previous_best = current_best;

        on_generation(generation, &pop);

        if ENABLE_EARLY_STOPPING && stalled_generations >= CONVERGENCE_WINDOW {
            converged_early = true;
            generations_run = generation;
            break;
        }

        pop = pop.evolve(rng);
    }

    // The population coming out of the last evolve step has not been scored
    // yet; on an early stop the current population already has been.
    if !converged_early {
        pop.eval_fitness();
        hall_of_fame.add_from_population(&pop.chromosomes);
    }

    GaResult {
        hall_of_fame,
        generations_run,
        converged_early,
        final_population: pop,
    }
}
/// Indices of the hall-of-fame entries chosen for output by
/// [`select_models`].
pub struct SelectedModels {
/// `(entry index, cluster size)` pairs: cluster centres ordered by
/// ascending fitness, followed by fill-in entries (recorded with size `1`)
/// when there are fewer clusters than `NUM_OUTPUT_MODELS`.
pub clustered: Vec<(usize, usize)>,
/// Entry indices ordered by ascending fitness, at most
/// `NUM_OUTPUT_MODELS` of them.
pub ranked: Vec<usize>,
}
/// Chooses which hall-of-fame entries to emit as output models.
///
/// For every entry the ligand is cloned, rotated with genes `0..=2`,
/// displaced with genes `3..=5`, and combined with the receptor. The
/// resulting complexes are clustered with the default [`ClusteringConfig`].
///
/// Two selections are produced, each capped at `NUM_OUTPUT_MODELS`:
/// - `clustered`: cluster centres in ascending fitness order, paired with
///   their cluster size. If there are fewer clusters than requested, the
///   list is topped up with the lowest-fitness entries that are not already
///   present, each recorded with size `1`.
/// - `ranked`: entry indices in ascending fitness order, ignoring clusters.
///
/// Fitness values that are NaN are ordered after every other value so the
/// sort comparator is always a consistent total order. Ties keep their
/// original relative order.
///
/// An empty `hof_entries` yields two empty selections without invoking the
/// clustering routine.
///
/// # Panics
/// Indexing `genes[0..=5]` and `hof_entries[center_idx]` panics if an entry
/// has fewer than six genes or the clustering routine reports a centre index
/// outside `hof_entries`.
pub fn select_models(
    hof_entries: &[HallOfFameEntry],
    receptor: &Molecule,
    ligand: &Molecule,
) -> SelectedModels {
    use std::cmp::Ordering;

    // Ascending order with every NaN placed after all ordinary values.
    // Identical to `partial_cmp` for non-NaN inputs, but remains a valid
    // total order when NaN is present, which `sort_by` requires.
    fn nan_last(a: f64, b: f64) -> Ordering {
        match (a.is_nan(), b.is_nan()) {
            (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
        }
    }

    if hof_entries.is_empty() {
        return SelectedModels {
            clustered: Vec::new(),
            ranked: Vec::new(),
        };
    }

    // Build the receptor+ligand complex encoded by each entry's genes.
    let complexes: Vec<Molecule> = hof_entries
        .iter()
        .map(|e| {
            let docked = ligand
                .clone()
                .rotate(e.genes[0], e.genes[1], e.genes[2])
                .displace(e.genes[3], e.genes[4], e.genes[5]);
            combine_molecules(receptor, &docked)
        })
        .collect();

    let cluster_config = ClusteringConfig::default();
    let clusters = clustering::cluster_structures(&complexes, &cluster_config);

    // (centre index, centre fitness, cluster size), best fitness first.
    let mut cluster_centers: Vec<(usize, f64, usize)> = clusters
        .iter()
        .map(|c| (c.center_idx, hof_entries[c.center_idx].fitness, c.size))
        .collect();
    cluster_centers.sort_by(|a, b| nan_last(a.1, b.1));

    let mut clustered: Vec<(usize, usize)> = cluster_centers
        .iter()
        .take(NUM_OUTPUT_MODELS)
        .map(|&(idx, _, size)| (idx, size))
        .collect();

    // One global ranking serves both the fill-in step and `ranked`.
    let mut ranked_by_fitness: Vec<(usize, f64)> = hof_entries
        .iter()
        .enumerate()
        .map(|(i, e)| (i, e.fitness))
        .collect();
    ranked_by_fitness.sort_by(|a, b| nan_last(a.1, b.1));

    if clustered.len() < NUM_OUTPUT_MODELS {
        let used: HashSet<usize> = clustered.iter().map(|&(i, _)| i).collect();
        let missing = NUM_OUTPUT_MODELS - clustered.len();
        // The sort is stable, so filtering the global ranking gives the same
        // order as sorting only the unused entries.
        clustered.extend(
            ranked_by_fitness
                .iter()
                .map(|&(i, _)| i)
                .filter(|i| !used.contains(i))
                .take(missing)
                .map(|i| (i, 1)),
        );
    }

    let ranked: Vec<usize> = ranked_by_fitness
        .iter()
        .take(NUM_OUTPUT_MODELS)
        .map(|&(idx, _)| idx)
        .collect();

    SelectedModels { clustered, ranked }
}