use super::{Budget, ConvergenceTracker, OptimizationResult, SearchSpace, TerminationReason};
use crate::metaheuristics::traits::PerturbativeMetaheuristic;
use rand::prelude::*;
/// Generational genetic algorithm over binary-encoded solutions.
///
/// Every individual is a `Vec<f64>` whose genes are `0.0` or `1.0`. The
/// objective is minimized: lower fitness is better.
#[derive(Debug, Clone)]
pub struct BinaryGA {
    /// Number of individuals in each generation.
    pub population_size: usize,
    /// Number of individuals sampled (with replacement) per selection tournament.
    pub tournament_size: usize,
    /// Probability that a selected parent pair is recombined by uniform crossover
    /// instead of being copied.
    pub crossover_prob: f64,
    /// Independent per-gene probability of flipping a bit during mutation.
    pub mutation_prob: f64,
    /// Number of best individuals carried unchanged into the next generation.
    pub elitism: usize,
    /// RNG seed; `None` makes `optimize` use `rand::rng()` instead of a seeded `StdRng`.
    seed: Option<u64>,
    /// Current population, one gene vector per individual.
    population: Vec<Vec<f64>>,
    /// Objective value of each individual, index-aligned with `population`.
    fitness: Vec<f64>,
    /// Index into `population` of the best (lowest-fitness) individual.
    best_idx: usize,
    /// Best fitness after initialization and after every generation.
    history: Vec<f64>,
}
impl Default for BinaryGA {
    /// Stock configuration: 100 individuals, tournaments of 3, crossover
    /// probability 0.9, per-gene mutation probability 0.01, 2 elites and no
    /// fixed seed. All run state starts empty.
    fn default() -> Self {
        let population_size = 100;
        let tournament_size = 3;
        let crossover_prob = 0.9;
        let mutation_prob = 0.01;
        let elitism = 2;

        Self {
            population_size,
            tournament_size,
            crossover_prob,
            mutation_prob,
            elitism,
            seed: None,
            // Run state; filled in by `optimize`.
            population: Vec::default(),
            fitness: Vec::default(),
            best_idx: 0,
            history: Vec::default(),
        }
    }
}
impl BinaryGA {
    /// Sets the number of individuals per generation.
    ///
    /// Values below 4 are raised to 4.
    #[must_use]
    pub fn with_population_size(mut self, size: usize) -> Self {
        self.population_size = size.max(4);
        self
    }

    /// Fixes the RNG seed so that `optimize` draws from a `StdRng` seeded with
    /// `seed`, making runs repeatable.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Sets the per-gene bit-flip probability, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_mutation_prob(mut self, prob: f64) -> Self {
        self.mutation_prob = prob.clamp(0.0, 1.0);
        self
    }

    /// Sets the probability that a selected parent pair undergoes uniform
    /// crossover, clamped to `[0.0, 1.0]`.
    #[must_use]
    pub fn with_crossover_prob(mut self, prob: f64) -> Self {
        self.crossover_prob = prob.clamp(0.0, 1.0);
        self
    }

    /// Returns `true` if `candidate` is strictly better (lower) than `incumbent`.
    ///
    /// NaN is treated as worse than every number. A NaN incumbent is displaced
    /// by any non-NaN candidate, and a NaN candidate never wins. For non-NaN
    /// values this is exactly `candidate < incumbent`.
    fn is_better(candidate: f64, incumbent: f64) -> bool {
        candidate < incumbent || (incumbent.is_nan() && !candidate.is_nan())
    }

    /// Tournament selection.
    ///
    /// Samples `tournament_size` individuals uniformly with replacement (always
    /// at least one) and returns the index of the fittest. The earliest sample
    /// wins ties.
    ///
    /// Panics (via `random_range` on an empty range) if `population_size` is 0.
    fn tournament_select(&self, rng: &mut impl Rng) -> usize {
        let mut best = rng.random_range(0..self.population_size);
        for _ in 1..self.tournament_size {
            let candidate = rng.random_range(0..self.population_size);
            if Self::is_better(self.fitness[candidate], self.fitness[best]) {
                best = candidate;
            }
        }
        best
    }

    /// Uniform crossover.
    ///
    /// For each gene position, a fair coin decides whether the children
    /// inherit the parents' genes straight or swapped. The two children are
    /// complementary.
    ///
    /// Both parents must have the same length.
    fn uniform_crossover(p1: &[f64], p2: &[f64], rng: &mut impl Rng) -> (Vec<f64>, Vec<f64>) {
        debug_assert_eq!(p1.len(), p2.len(), "parents must have equal length");
        p1.iter()
            .zip(p2)
            .map(|(&a, &b)| if rng.random::<f64>() < 0.5 { (a, b) } else { (b, a) })
            .unzip()
    }

    /// Bit-flip mutation.
    ///
    /// Each gene is flipped independently with probability `mutation_prob`.
    /// Genes above `0.5` become `0.0`; all others become `1.0`.
    fn mutate(&self, x: &mut [f64], rng: &mut impl Rng) {
        for gene in x.iter_mut() {
            if rng.random::<f64>() < self.mutation_prob {
                *gene = if *gene > 0.5 { 0.0 } else { 1.0 };
            }
        }
    }

    /// Returns, in ascending order, the indices of the genes that are "on"
    /// (greater than `0.5`) in `solution`.
    #[must_use]
    pub fn selected_features(solution: &[f64]) -> Vec<usize> {
        solution
            .iter()
            .enumerate()
            .filter(|(_, &b)| b > 0.5)
            .map(|(i, _)| i)
            .collect()
    }
}
impl PerturbativeMetaheuristic for BinaryGA {
    type Solution = Vec<f64>;

    /// Minimizes `objective` over bit vectors (genes are `0.0` or `1.0`) with a
    /// generational GA. Each generation uses elitism, tournament selection,
    /// uniform crossover and bit-flip mutation.
    ///
    /// Runs for at most `budget.max_iterations(population_size)` generations.
    /// It stops early as soon as the convergence tracker's `update` returns
    /// `false`. Lower objective values are better. NaN objective values are
    /// treated as worse than every number.
    ///
    /// # Panics
    ///
    /// Panics if `space` is neither `SearchSpace::Binary` nor
    /// `SearchSpace::Continuous` (only its `dim` is used). Also panics if
    /// `population_size` is zero.
    fn optimize<F>(
        &mut self,
        objective: &F,
        space: &SearchSpace,
        budget: Budget,
    ) -> OptimizationResult<Self::Solution>
    where
        F: Fn(&[f64]) -> f64,
    {
        use std::cmp::Ordering;

        // Ascending total order on fitness, with NaN placed after every number.
        // `partial_cmp(..).unwrap_or(Equal)` is not a total order once NaN
        // appears (NaN would be "equal" to both 1.0 and 2.0). That lets a NaN
        // individual win `min_by` and can make `sort_by` panic on newer Rust.
        fn fitness_cmp(a: f64, b: f64) -> Ordering {
            match (a.is_nan(), b.is_nan()) {
                (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
                (true, false) => Ordering::Greater,
                (false, true) => Ordering::Less,
                (true, true) => Ordering::Equal,
            }
        }

        // Index of the first best individual (0 for an empty slice).
        fn best_index(fitness: &[f64]) -> usize {
            fitness
                .iter()
                .enumerate()
                .min_by(|a, b| fitness_cmp(*a.1, *b.1))
                .map_or(0, |(i, _)| i)
        }

        let mut rng: Box<dyn RngCore> = match self.seed {
            Some(s) => Box::new(StdRng::seed_from_u64(s)),
            None => Box::new(rand::rng()),
        };
        // Only the dimensionality is used; any other fields of a continuous
        // space are ignored.
        let dim = match space {
            SearchSpace::Binary { dim } | SearchSpace::Continuous { dim, .. } => *dim,
            _ => panic!("BinaryGA requires Binary or Continuous search space"),
        };
        assert!(
            self.population_size > 0,
            "BinaryGA requires population_size > 0"
        );

        // Random initial population: each gene is 0.0 or 1.0 with equal probability.
        self.population = (0..self.population_size)
            .map(|_| {
                (0..dim)
                    .map(|_| if rng.random::<f64>() < 0.5 { 0.0 } else { 1.0 })
                    .collect()
            })
            .collect();
        self.fitness = self.population.iter().map(|x| objective(x)).collect();
        self.best_idx = best_index(&self.fitness);
        self.history.clear();
        self.history.push(self.fitness[self.best_idx]);

        let mut tracker = ConvergenceTracker::from_budget(&budget);
        tracker.update(self.fitness[self.best_idx], self.population_size);
        let max_iter = budget.max_iterations(self.population_size);

        for _gen in 0..max_iter {
            // Rank individuals best-first. The sort is stable, so lower indices
            // come first among ties and the choice of elites is deterministic.
            let mut ranked: Vec<usize> = (0..self.population_size).collect();
            ranked.sort_by(|&a, &b| fitness_cmp(self.fitness[a], self.fitness[b]));

            let mut offspring = Vec::with_capacity(self.population_size);
            let mut offspring_fit = Vec::with_capacity(self.population_size);

            // Elites survive unchanged and keep their cached fitness (not re-evaluated).
            for &i in ranked.iter().take(self.elitism) {
                offspring.push(self.population[i].clone());
                offspring_fit.push(self.fitness[i]);
            }

            while offspring.len() < self.population_size {
                let p1 = self.tournament_select(&mut rng);
                let p2 = self.tournament_select(&mut rng);
                let (mut c1, mut c2) = if rng.random::<f64>() < self.crossover_prob {
                    Self::uniform_crossover(&self.population[p1], &self.population[p2], &mut rng)
                } else {
                    (self.population[p1].clone(), self.population[p2].clone())
                };
                // Both children are always mutated, so the RNG stream (and hence a
                // seeded run) does not depend on whether the second child is kept.
                self.mutate(&mut c1, &mut rng);
                self.mutate(&mut c2, &mut rng);

                offspring_fit.push(objective(&c1));
                offspring.push(c1);
                // Score the second child only if it will actually be kept; the
                // objective may be expensive, so don't spend a call on a discard.
                if offspring.len() < self.population_size {
                    offspring_fit.push(objective(&c2));
                    offspring.push(c2);
                }
            }

            self.population = offspring;
            self.fitness = offspring_fit;
            self.best_idx = best_index(&self.fitness);
            self.history.push(self.fitness[self.best_idx]);
            // NOTE(review): charges a nominal `population_size` evaluations per
            // generation even though elites are not re-evaluated. Presumably this
            // keeps the accounting consistent with
            // `budget.max_iterations(population_size)`; confirm before changing.
            if !tracker.update(self.fitness[self.best_idx], self.population_size) {
                break;
            }
        }

        let termination = if tracker.is_converged() {
            TerminationReason::Converged
        } else if tracker.is_exhausted() {
            TerminationReason::BudgetExhausted
        } else {
            TerminationReason::MaxIterations
        };

        OptimizationResult {
            solution: self.population[self.best_idx].clone(),
            objective_value: self.fitness[self.best_idx],
            evaluations: tracker.evaluations(),
            // `history` has one entry for the initial population plus one per
            // generation, so this is generations run + 1.
            iterations: self.history.len(),
            history: self.history.clone(),
            termination,
        }
    }

    /// Returns the best individual from the most recent run, or `None` when no
    /// population exists (before any run, or after `reset`).
    fn best(&self) -> Option<&Self::Solution> {
        if self.population.is_empty() {
            None
        } else {
            Some(&self.population[self.best_idx])
        }
    }

    /// Best fitness after initialization and after each generation of the
    /// most recent run.
    fn history(&self) -> &[f64] {
        &self.history
    }

    /// Clears all run state. Configuration, including the seed, is kept.
    fn reset(&mut self) {
        self.population.clear();
        self.fitness.clear();
        self.best_idx = 0;
        self.history.clear();
    }
}
// Unit tests are kept in a separate file, `binary_ga_tests.rs`, loaded via `#[path]`.
#[cfg(test)]
#[path = "binary_ga_tests.rs"]
mod tests;