use rand::prelude::*;
use serde::{Deserialize, Serialize};
use super::{Budget, OptimizationResult, PerturbativeMetaheuristic, SearchSpace};
use crate::metaheuristics::budget::ConvergenceTracker;
use crate::metaheuristics::traits::TerminationReason;
/// Real-coded genetic algorithm: tournament selection, SBX crossover and
/// polynomial mutation over a continuous, box-bounded search space.
///
/// The public fields are the tunable hyper-parameters and are (de)serialized.
/// The private fields marked `#[serde(skip)]` hold the state of the most
/// recent run and are never serialized.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GeneticAlgorithm {
/// Number of individuals in every generation.
pub population_size: usize,
/// Probability that a selected parent pair is recombined at all; otherwise
/// the two children start as plain copies of their parents.
pub crossover_prob: f64,
/// Per-coordinate mutation probability. `None` resolves to `1 / dim` when
/// `optimize` runs.
pub mutation_prob: Option<f64>,
/// Distribution index of the SBX crossover (enters as the exponent
/// `1 / (sbx_eta + 1)` of the spread factor).
pub sbx_eta: f64,
/// Distribution index of the polynomial mutation (enters as the exponent
/// `1 / (mutation_eta + 1)` of the perturbation).
pub mutation_eta: f64,
/// Number of individuals drawn (independently, so repeats are possible) per
/// tournament; the one with the lowest fitness wins.
pub tournament_size: usize,
/// Seed for the run's random number generator; `None` uses `rand::rng()`.
/// Optional in serialized input.
#[serde(default)]
seed: Option<u64>,
/// Individuals of the current generation.
#[serde(skip)]
population: Vec<Vec<f64>>,
/// Objective value of each individual, index-aligned with `population`.
#[serde(skip)]
fitness: Vec<f64>,
/// Index of the lowest-fitness individual in `population`.
#[serde(skip)]
best_idx: usize,
/// Best fitness of the initial population followed by the best fitness of
/// every generation produced so far.
#[serde(skip)]
history: Vec<f64>,
}
impl Default for GeneticAlgorithm {
    /// Returns an unseeded optimizer with 100 individuals, binary tournaments,
    /// a 0.9 crossover probability, both distribution indices at 20 and the
    /// mutation probability left to be derived from the problem dimension.
    fn default() -> Self {
        Self {
            // Selection.
            population_size: 100,
            tournament_size: 2,
            // Variation operators.
            crossover_prob: 0.9,
            sbx_eta: 20.0,
            mutation_prob: None,
            mutation_eta: 20.0,
            // Randomness: no fixed seed unless the caller asks for one.
            seed: None,
            // Run state starts out empty; `optimize` fills it in.
            best_idx: 0,
            population: Vec::new(),
            fitness: Vec::new(),
            history: Vec::new(),
        }
    }
}
impl GeneticAlgorithm {
    /// Builder-style setter for the number of individuals per generation.
    #[must_use]
    pub fn with_population_size(mut self, size: usize) -> Self {
        self.population_size = size;
        self
    }

    /// Builder-style setter that fixes the RNG seed used by `optimize`.
    #[must_use]
    pub fn with_seed(mut self, seed: u64) -> Self {
        self.seed = Some(seed);
        self
    }

    /// Picks one parent index by tournament selection (minimisation).
    ///
    /// Draws `tournament_size` indices uniformly from `0..population_size`
    /// (independent draws, so repeats are possible) and returns the one with
    /// the lowest fitness. At least one index is always drawn, even when
    /// `tournament_size` is 0 or 1.
    ///
    /// # Panics
    ///
    /// Panics if `population_size` is 0 or exceeds `fitness.len()`.
    fn tournament_select(&self, rng: &mut impl Rng) -> usize {
        let mut best = rng.random_range(0..self.population_size);
        for _ in 1..self.tournament_size {
            let candidate = rng.random_range(0..self.population_size);
            if self.fitness[candidate] < self.fitness[best] {
                best = candidate;
            }
        }
        best
    }

    /// Spread factor of bounded SBX for one child.
    ///
    /// `u` is the uniform sample in `[0, 1)`, `boundary_gap` the distance from
    /// the parent nearest to a bound to that bound, and `parent_gap` the
    /// (non-zero) distance between the two parents.
    fn sbx_spread(&self, u: f64, boundary_gap: f64, parent_gap: f64) -> f64 {
        let exponent = self.sbx_eta + 1.0;
        let beta = 1.0 + 2.0 * boundary_gap / parent_gap;
        let alpha = 2.0 - beta.powf(-exponent);
        if u <= 1.0 / alpha {
            (u * alpha).powf(1.0 / exponent)
        } else {
            (1.0 / (2.0 - u * alpha)).powf(1.0 / exponent)
        }
    }

    /// Simulated binary crossover (SBX) of two parents inside `[lower, upper]`.
    ///
    /// With probability `crossover_prob` the pair is recombined; each
    /// coordinate is then touched with probability 0.5, provided the parents
    /// differ there by more than `1e-14`. Untouched coordinates are copied
    /// from `p1` into the first child and from `p2` into the second.
    ///
    /// For a recombined coordinate the first child receives the value on the
    /// lower side and the second child the value on the upper side. Each side
    /// uses a spread factor derived from its *own* bound: the lower child from
    /// `lower[i]`, the upper child from `upper[i]`. Reusing the lower-bound
    /// factor for the upper child would let it overshoot and accumulate on the
    /// upper bound through the final clamp.
    ///
    /// # Panics
    ///
    /// Panics if `p2`, `lower` or `upper` is shorter than `p1`, or if
    /// `lower[i] > upper[i]` for a recombined coordinate.
    fn sbx_crossover(
        &self,
        p1: &[f64],
        p2: &[f64],
        lower: &[f64],
        upper: &[f64],
        rng: &mut impl Rng,
    ) -> (Vec<f64>, Vec<f64>) {
        let mut c1 = p1.to_vec();
        let mut c2 = p2.to_vec();
        if rng.random::<f64>() < self.crossover_prob {
            for i in 0..p1.len() {
                if rng.random::<f64>() < 0.5 {
                    let y1 = p1[i].min(p2[i]);
                    let y2 = p1[i].max(p2[i]);
                    let gap = y2 - y1;
                    if gap.abs() > 1e-14 {
                        // One uniform sample is shared by both children.
                        let u: f64 = rng.random();
                        let betaq_lo = self.sbx_spread(u, y1 - lower[i], gap);
                        let betaq_hi = self.sbx_spread(u, upper[i] - y2, gap);
                        c1[i] = (0.5 * ((y1 + y2) - betaq_lo * gap)).clamp(lower[i], upper[i]);
                        c2[i] = (0.5 * ((y1 + y2) + betaq_hi * gap)).clamp(lower[i], upper[i]);
                    }
                }
            }
        }
        (c1, c2)
    }

    /// Polynomial mutation applied in place to `x`.
    ///
    /// Every coordinate is mutated independently with probability `mut_prob`.
    /// A mutated coordinate is shifted by `delta * (upper[i] - lower[i])`,
    /// where `delta` lies in `[-1, 1]` and is shaped by `mutation_eta`, and is
    /// then clamped back into `[lower[i], upper[i]]`.
    ///
    /// # Panics
    ///
    /// Panics if `lower` or `upper` is shorter than `x`, or if
    /// `lower[i] > upper[i]` for a mutated coordinate.
    fn polynomial_mutate(
        &self,
        x: &mut [f64],
        lower: &[f64],
        upper: &[f64],
        mut_prob: f64,
        rng: &mut impl Rng,
    ) {
        let exponent = 1.0 / (self.mutation_eta + 1.0);
        for i in 0..x.len() {
            if rng.random::<f64>() < mut_prob {
                let delta_max = upper[i] - lower[i];
                let u: f64 = rng.random();
                // u < 0.5 pushes towards the lower bound, otherwise upwards.
                let delta = if u < 0.5 {
                    (2.0 * u).powf(exponent) - 1.0
                } else {
                    1.0 - (2.0 * (1.0 - u)).powf(exponent)
                };
                x[i] = (x[i] + delta * delta_max).clamp(lower[i], upper[i]);
            }
        }
    }
}
impl PerturbativeMetaheuristic for GeneticAlgorithm {
    type Solution = Vec<f64>;

    /// Minimises `objective` over a continuous `space` with a generational GA.
    ///
    /// A uniformly random population is created and evaluated. Each generation
    /// then replaces the whole population with offspring produced by
    /// tournament selection, SBX crossover and polynomial mutation. There is
    /// no elitism: the returned solution and value are the best of the *final*
    /// generation, and `history` holds the per-generation best, which is not
    /// guaranteed to be monotone.
    ///
    /// Exactly `population_size` objective evaluations are performed for the
    /// initial population and for every generation, matching what is reported
    /// to the convergence tracker.
    ///
    /// # Panics
    ///
    /// Panics if `space` is not `SearchSpace::Continuous`, if
    /// `population_size` is 0, or if a lower bound exceeds its upper bound.
    fn optimize<F>(
        &mut self,
        objective: &F,
        space: &SearchSpace,
        budget: Budget,
    ) -> OptimizationResult<Self::Solution>
    where
        F: Fn(&[f64]) -> f64,
    {
        // Index of the smallest value; incomparable pairs (NaN) count as
        // equal and an empty slice yields 0.
        fn argmin(values: &[f64]) -> usize {
            values
                .iter()
                .enumerate()
                .min_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
                .map_or(0, |(i, _)| i)
        }

        let mut rng: Box<dyn RngCore> = match self.seed {
            Some(s) => Box::new(StdRng::seed_from_u64(s)),
            None => Box::new(rand::rng()),
        };
        let (lower, upper, dim) = match space {
            SearchSpace::Continuous { dim, lower, upper } => (lower.clone(), upper.clone(), *dim),
            _ => panic!("GA requires continuous search space"),
        };
        // An empty population has no best individual to report; fail with a
        // clear message instead of an index-out-of-bounds panic further down.
        assert!(
            self.population_size > 0,
            "GA requires a population size of at least 1"
        );
        let mut_prob = self.mutation_prob.unwrap_or(1.0 / dim as f64);

        // Initial population: uniform samples inside the box.
        self.population = (0..self.population_size)
            .map(|_| {
                (0..dim)
                    .map(|j| rng.random_range(lower[j]..=upper[j]))
                    .collect()
            })
            .collect();
        self.fitness = self.population.iter().map(|x| objective(x)).collect();
        self.best_idx = argmin(&self.fitness);
        self.history.clear();
        self.history.push(self.fitness[self.best_idx]);

        let mut tracker = ConvergenceTracker::from_budget(&budget);
        tracker.update(self.fitness[self.best_idx], self.population_size);
        let max_iter = budget.max_iterations(self.population_size);

        for _ in 0..max_iter {
            let mut offspring = Vec::with_capacity(self.population_size);
            let mut offspring_fit = Vec::with_capacity(self.population_size);
            while offspring.len() < self.population_size {
                let p1 = self.tournament_select(&mut rng);
                let p2 = self.tournament_select(&mut rng);
                let (mut c1, mut c2) = self.sbx_crossover(
                    &self.population[p1],
                    &self.population[p2],
                    &lower,
                    &upper,
                    &mut rng,
                );
                // Both children are always mutated so the random stream does
                // not depend on whether the second one is kept.
                self.polynomial_mutate(&mut c1, &lower, &upper, mut_prob, &mut rng);
                self.polynomial_mutate(&mut c2, &lower, &upper, mut_prob, &mut rng);
                offspring_fit.push(objective(&c1));
                offspring.push(c1);
                // With an odd population size the last pair has room for only
                // one child; skip evaluating the one that would be discarded.
                if offspring.len() < self.population_size {
                    offspring_fit.push(objective(&c2));
                    offspring.push(c2);
                }
            }
            // Full generational replacement.
            self.population = offspring;
            self.fitness = offspring_fit;
            self.best_idx = argmin(&self.fitness);
            self.history.push(self.fitness[self.best_idx]);
            // The tracker returns `false` when the run should stop.
            if !tracker.update(self.fitness[self.best_idx], self.population_size) {
                break;
            }
        }

        let termination = if tracker.is_converged() {
            TerminationReason::Converged
        } else if tracker.is_exhausted() {
            TerminationReason::BudgetExhausted
        } else {
            TerminationReason::MaxIterations
        };
        OptimizationResult::new(
            self.population[self.best_idx].clone(),
            self.fitness[self.best_idx],
            tracker.evaluations(),
            // Counts the initial population as well as every generation.
            self.history.len(),
            self.history.clone(),
            termination,
        )
    }

    /// Best individual of the current population, or `None` before the first
    /// run and after `reset`.
    fn best(&self) -> Option<&Self::Solution> {
        self.population.get(self.best_idx)
    }

    /// Best fitness of the initial population followed by that of every
    /// generation of the most recent run.
    fn history(&self) -> &[f64] {
        &self.history
    }

    /// Discards all run state; the hyper-parameters and seed are kept.
    fn reset(&mut self) {
        self.population.clear();
        self.fitness.clear();
        self.best_idx = 0;
        self.history.clear();
    }
}
// Additional test module whose source lives in `tests_ga_contract.rs`;
// compiled only for test builds.
#[cfg(test)]
#[path = "tests_ga_contract.rs"]
mod tests_ga_contract;
#[cfg(test)]
mod tests {
    use super::*;

    /// Sum of squares of the coordinates.
    fn sphere(x: &[f64]) -> f64 {
        x.iter().map(|xi| xi * xi).sum()
    }

    /// Plain sum of the coordinates.
    fn coordinate_sum(x: &[f64]) -> f64 {
        x.iter().sum::<f64>()
    }

    /// A seeded run on the 5-D sphere must end below an objective value of 5.
    #[test]
    fn test_ga_sphere() {
        let space = SearchSpace::continuous(5, -5.0, 5.0);
        let mut ga = GeneticAlgorithm::default()
            .with_seed(42)
            .with_population_size(50);
        let result = ga.optimize(&sphere, &space, Budget::Evaluations(5000));
        assert!(result.objective_value < 5.0);
    }

    /// The builder methods store the requested population size.
    #[test]
    fn test_ga_builder() {
        let ga = GeneticAlgorithm::default()
            .with_population_size(200)
            .with_seed(999);
        assert_eq!(ga.population_size, 200);
    }

    /// `reset` clears the run state so that `best` reports nothing again.
    #[test]
    fn test_ga_reset() {
        let space = SearchSpace::continuous(2, -1.0, 1.0);
        let mut ga = GeneticAlgorithm::default().with_seed(42);
        let _ = ga.optimize(&coordinate_sum, &space, Budget::Evaluations(200));
        assert!(ga.best().is_some());
        ga.reset();
        assert!(ga.best().is_none());
    }
}