1use rand::Rng;
7
8use crate::{crossover, mutation, selection, Chromosome};
9
/// Tuning parameters for [`run_ga`].
///
/// Start from [`GaConfig::default`] and override individual values with the
/// chained setter methods.
#[derive(Clone, Debug)]
pub struct GaConfig {
    /// Number of chromosomes in every generation.
    pub population_size: usize,
    /// Number of generations that are evaluated and ranked, including the final one.
    pub generations: usize,
    /// Rate handed to `mutation::bit_flip` for every child produced by crossover
    /// (presumably a per-gene flip probability — confirm against the `mutation` module).
    pub mutation_rate: f64,
    /// Number of best-ranked chromosomes copied unchanged into the next generation.
    /// `run_ga` only applies elitism to single-objective runs.
    pub elite_count: usize,
    /// Fraction of the ranked population, counted from the top, that may be picked as
    /// parents. `run_ga` enforces a minimum pool of two for populations of two or more.
    pub parent_fraction: f64,
}
24
25impl Default for GaConfig {
26 fn default() -> Self {
27 Self {
28 population_size: 100,
29 generations: 1000,
30 mutation_rate: 0.01,
31 elite_count: 2,
32 parent_fraction: 0.5,
33 }
34 }
35}
36
37impl GaConfig {
38 #[must_use]
40 pub fn population_size(mut self, n: usize) -> Self {
41 self.population_size = n;
42 self
43 }
44
45 #[must_use]
47 pub fn generations(mut self, n: usize) -> Self {
48 self.generations = n;
49 self
50 }
51
52 #[must_use]
54 pub fn mutation_rate(mut self, r: f64) -> Self {
55 self.mutation_rate = r;
56 self
57 }
58
59 #[must_use]
61 pub fn elite_count(mut self, n: usize) -> Self {
62 self.elite_count = n;
63 self
64 }
65
66 #[must_use]
68 pub fn parent_fraction(mut self, f: f64) -> Self {
69 self.parent_fraction = f;
70 self
71 }
72}
73
74#[derive(Debug)]
76pub struct Progress<'a> {
77 pub generation: usize,
79 pub population: &'a [Chromosome],
81}
82
83#[derive(Debug)]
85pub struct GaResult {
86 pub population: Vec<Chromosome>,
88 pub generations_run: usize,
90}
91
92pub fn run_ga<F, P, R>(
114 config: &GaConfig,
115 chromosome_size: usize,
116 mut evaluate: F,
117 mut on_progress: P,
118 rng: &mut R,
119) -> GaResult
120where
121 F: FnMut(&Chromosome) -> Vec<f64>,
122 P: FnMut(&Progress),
123 R: Rng,
124{
125 assert!(config.population_size > 0, "population_size must be > 0");
126 assert!(config.generations > 0, "generations must be > 0");
127
128 let mut population: Vec<Chromosome> = (0..config.population_size)
130 .map(|_| Chromosome::new(chromosome_size, rng))
131 .collect();
132
133 for generation in 0..config.generations {
134 for c in &mut population {
136 if c.fitness().is_none() {
137 let fitness = evaluate(c);
138 c.set_fitness(fitness);
139 }
140 }
141
142 let is_multi = population
144 .first()
145 .and_then(|c| c.fitness())
146 .is_some_and(|f| f.len() > 1);
147
148 let sorted: Vec<Chromosome> = if is_multi {
150 selection::nsga2(&population).collect()
151 } else {
152 selection::rank(&population).collect()
153 };
154
155 on_progress(&Progress {
157 generation,
158 population: &sorted,
159 });
160
161 if generation == config.generations - 1 {
163 return GaResult {
164 population: sorted,
165 generations_run: generation + 1,
166 };
167 }
168
169 let elite_count = if is_multi {
171 0
172 } else {
173 config.elite_count.min(config.population_size)
174 };
175 let pool_size =
176 ((config.population_size as f64 * config.parent_fraction) as usize).max(2);
177
178 let mut next_gen: Vec<Chromosome> = sorted[..elite_count].to_vec();
179
180 while next_gen.len() < config.population_size {
181 let p1 = &sorted[rng.random_range(0..pool_size)];
182 let p2 = &sorted[rng.random_range(0..pool_size)];
183
184 let (mut c1, mut c2) = crossover::single_point(p1, p2, rng);
185 mutation::bit_flip(&mut c1, config.mutation_rate, rng);
186 mutation::bit_flip(&mut c2, config.mutation_rate, rng);
187
188 next_gen.push(c1);
189 if next_gen.len() < config.population_size {
190 next_gen.push(c2);
191 }
192 }
193
194 population = next_gen;
195 }
196
197 GaResult {
198 population,
199 generations_run: config.generations,
200 }
201}
202
#[cfg(test)]
mod tests {
    use rand::rngs::StdRng;
    use rand::SeedableRng;

    use super::*;

    /// Fixed-seed RNG so every run of the suite sees the same populations.
    fn rng() -> StdRng {
        StdRng::seed_from_u64(42)
    }

    /// Single objective: the sum of the chromosome's genes.
    fn gene_sum(c: &Chromosome) -> Vec<f64> {
        let total: u32 = c.data().iter().map(|&b| b as u32).sum();
        vec![total as f64]
    }

    /// Two objectives: the sum of the genes and the largest gene.
    fn sum_and_max(c: &Chromosome) -> Vec<f64> {
        let genes = c.data();
        let total: u32 = genes.iter().map(|&b| b as u32).sum();
        let largest = genes.iter().copied().max().unwrap_or(0) as f64;
        vec![total as f64, largest]
    }

    #[test]
    fn run_ga_single_objective() {
        let config = GaConfig::default()
            .population_size(20)
            .generations(10)
            .mutation_rate(0.01);

        let mut last_gen = 0;
        let result = run_ga(
            &config,
            4,
            gene_sum,
            |p| last_gen = p.generation,
            &mut rng(),
        );

        assert_eq!(result.generations_run, 10);
        assert_eq!(result.population.len(), 20);
        assert_eq!(last_gen, 9);
        assert!(result.population[0].fitness().is_some());
    }

    #[test]
    fn run_ga_multi_objective() {
        let config = GaConfig::default()
            .population_size(20)
            .generations(5)
            .mutation_rate(0.01);

        let result = run_ga(&config, 4, sum_and_max, |_| {}, &mut rng());

        assert_eq!(result.generations_run, 5);
        let objective_count = result.population[0].fitness().map(|f| f.len());
        assert_eq!(objective_count, Some(2));
    }

    #[test]
    fn config_builder() {
        let GaConfig {
            population_size,
            generations,
            mutation_rate,
            elite_count,
            parent_fraction,
        } = GaConfig::default()
            .population_size(50)
            .generations(100)
            .mutation_rate(0.05)
            .elite_count(5)
            .parent_fraction(0.3);

        assert_eq!(population_size, 50);
        assert_eq!(generations, 100);
        assert!((mutation_rate - 0.05).abs() < f64::EPSILON);
        assert_eq!(elite_count, 5);
        assert!((parent_fraction - 0.3).abs() < f64::EPSILON);
    }

    #[test]
    #[should_panic(expected = "population_size must be > 0")]
    fn run_ga_panics_on_zero_population() {
        let config = GaConfig {
            population_size: 0,
            ..GaConfig::default()
        };
        let _ = run_ga(&config, 4, |_| vec![0.0], |_| {}, &mut rng());
    }

    #[test]
    #[should_panic(expected = "generations must be > 0")]
    fn run_ga_panics_on_zero_generations() {
        let config = GaConfig {
            generations: 0,
            ..GaConfig::default()
        };
        let _ = run_ga(&config, 4, |_| vec![0.0], |_| {}, &mut rng());
    }
}