Skip to main content

graphmind_optimization/algorithms/ga.rs

1use crate::common::{Individual, OptimizationResult, Problem, SolverConfig};
2use ndarray::Array1;
3use rand::prelude::*;
4
5pub struct GASolver {
6    pub config: SolverConfig,
7    pub crossover_rate: f64,
8    pub mutation_rate: f64,
9}
10
11impl GASolver {
12    pub fn new(config: SolverConfig) -> Self {
13        Self {
14            config,
15            crossover_rate: 0.8,
16            mutation_rate: 0.1,
17        }
18    }
19
20    pub fn solve<P: Problem>(&self, problem: &P) -> OptimizationResult {
21        let mut rng = thread_rng();
22        let dim = problem.dim();
23        let (lower, upper) = problem.bounds();
24
25        // 1. Initialize Population
26        let mut population: Vec<Individual> = (0..self.config.population_size)
27            .map(|_| {
28                let mut vars = Array1::zeros(dim);
29                for i in 0..dim {
30                    vars[i] = rng.gen_range(lower[i]..upper[i]);
31                }
32                let fitness = problem.fitness(&vars);
33                Individual::new(vars, fitness)
34            })
35            .collect();
36
37        let mut history = Vec::with_capacity(self.config.max_iterations);
38
39        for _iter in 0..self.config.max_iterations {
40            // Find best for history
41            let mut best_idx = 0;
42            for i in 1..population.len() {
43                if population[i].fitness < population[best_idx].fitness {
44                    best_idx = i;
45                }
46            }
47            history.push(population[best_idx].fitness);
48
49            // 2. Evolution
50            let mut new_population = Vec::with_capacity(self.config.population_size);
51
52            // Elitism: carry over the best
53            new_population.push(population[best_idx].clone());
54
55            while new_population.len() < self.config.population_size {
56                // Selection (Tournament)
57                let p1 = self.select(&population);
58                let p2 = self.select(&population);
59
60                // Crossover
61                let (mut c1_vars, mut c2_vars) = if rng.gen::<f64>() < self.crossover_rate {
62                    self.crossover(&p1.variables, &p2.variables)
63                } else {
64                    (p1.variables.clone(), p2.variables.clone())
65                };
66
67                // Mutation
68                self.mutate(&mut c1_vars, &lower, &upper);
69                self.mutate(&mut c2_vars, &lower, &upper);
70
71                // Add to new population
72                let f1 = problem.fitness(&c1_vars);
73                new_population.push(Individual::new(c1_vars, f1));
74
75                if new_population.len() < self.config.population_size {
76                    let f2 = problem.fitness(&c2_vars);
77                    new_population.push(Individual::new(c2_vars, f2));
78                }
79            }
80
81            population = new_population;
82        }
83
84        let mut best_idx = 0;
85        for i in 1..population.len() {
86            if population[i].fitness < population[best_idx].fitness {
87                best_idx = i;
88            }
89        }
90
91        OptimizationResult {
92            best_variables: population[best_idx].variables.clone(),
93            best_fitness: population[best_idx].fitness,
94            history,
95        }
96    }
97
98    fn select<'a>(&self, population: &'a [Individual]) -> &'a Individual {
99        let mut rng = thread_rng();
100        let i1 = rng.gen_range(0..population.len());
101        let i2 = rng.gen_range(0..population.len());
102
103        if population[i1].fitness < population[i2].fitness {
104            &population[i1]
105        } else {
106            &population[i2]
107        }
108    }
109
110    fn crossover(&self, p1: &Array1<f64>, p2: &Array1<f64>) -> (Array1<f64>, Array1<f64>) {
111        let mut rng = thread_rng();
112        let dim = p1.len();
113        let mut c1 = p1.clone();
114        let mut c2 = p2.clone();
115
116        // Uniform crossover
117        for i in 0..dim {
118            if rng.gen_bool(0.5) {
119                std::mem::swap(&mut c1[i], &mut c2[i]);
120            }
121        }
122        (c1, c2)
123    }
124
125    fn mutate(&self, vars: &mut Array1<f64>, lower: &Array1<f64>, upper: &Array1<f64>) {
126        let mut rng = thread_rng();
127        let dim = vars.len();
128
129        for i in 0..dim {
130            if rng.gen::<f64>() < self.mutation_rate {
131                // Small Gaussian mutation or random reset?
132                // Let's use Gaussian mutation for continuous space
133                let range = upper[i] - lower[i];
134                let delta = rand_distr::Normal::new(0.0, range * 0.1)
135                    .unwrap()
136                    .sample(&mut rng);
137                vars[i] = (vars[i] + delta).clamp(lower[i], upper[i]);
138            }
139        }
140    }
141}