1use crate::genome::{BehaviorDescriptor, Genome};
4use crate::population::{Individual, Population};
5use rand::Rng;
6use std::sync::{Arc, RwLock};
7
8pub trait Selection<G: Genome>: Send + Sync {
10 fn select<'a, R: Rng>(
12 &self,
13 population: &'a Population<G>,
14 rng: &mut R,
15 ) -> &'a Individual<G>;
16}
17
18pub struct Tournament {
20 pub size: usize,
22}
23
24impl Tournament {
25 pub fn new(size: usize) -> Self {
27 Self { size }
28 }
29}
30
31impl<G: Genome> Selection<G> for Tournament {
32 fn select<'a, R: Rng>(
33 &self,
34 population: &'a Population<G>,
35 rng: &mut R,
36 ) -> &'a Individual<G> {
37 let n = population.individuals.len();
38 let mut best: Option<&Individual<G>> = None;
39 let mut best_fitness = f64::NEG_INFINITY;
40
41 for _ in 0..self.size {
42 let idx = rng.gen_range(0..n);
43 let ind = &population.individuals[idx];
44 let fitness = ind.fitness_value();
45
46 if fitness > best_fitness {
47 best_fitness = fitness;
48 best = Some(ind);
49 }
50 }
51
52 best.unwrap_or(&population.individuals[0])
53 }
54}
55
56#[derive(Clone)]
58pub struct NoveltyArchive {
59 behaviors: Vec<BehaviorDescriptor>,
61 max_size: usize,
63 add_threshold: f64,
65}
66
67impl NoveltyArchive {
68 pub fn new(max_size: usize, add_threshold: f64) -> Self {
70 Self {
71 behaviors: Vec::new(),
72 max_size,
73 add_threshold,
74 }
75 }
76
77 pub fn add(&mut self, behavior: &BehaviorDescriptor, novelty: f64) -> bool {
79 if novelty >= self.add_threshold && self.behaviors.len() < self.max_size {
80 self.behaviors.push(behavior.clone());
81 true
82 } else {
83 false
84 }
85 }
86
87 pub fn force_add(&mut self, behavior: BehaviorDescriptor) {
89 if self.behaviors.len() < self.max_size {
90 self.behaviors.push(behavior);
91 }
92 }
93
94 pub fn behaviors(&self) -> &[BehaviorDescriptor] {
96 &self.behaviors
97 }
98
99 pub fn len(&self) -> usize {
101 self.behaviors.len()
102 }
103
104 pub fn is_empty(&self) -> bool {
106 self.behaviors.is_empty()
107 }
108
109 pub fn compute_novelty(&self, behavior: &BehaviorDescriptor, k: usize, population_behaviors: &[&BehaviorDescriptor]) -> f64 {
112 let all_behaviors: Vec<&BehaviorDescriptor> = self.behaviors
114 .iter()
115 .chain(population_behaviors.iter().copied())
116 .collect();
117
118 if all_behaviors.is_empty() {
119 return f64::MAX; }
121
122 let mut distances: Vec<f64> = all_behaviors
124 .iter()
125 .map(|other| behavior.distance(other))
126 .filter(|d| *d > 0.0) .collect();
128
129 if distances.is_empty() {
130 return 0.0;
131 }
132
133 distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
135 let k = k.min(distances.len());
136
137 distances.iter().take(k).sum::<f64>() / k as f64
139 }
140}
141
142impl Default for NoveltyArchive {
143 fn default() -> Self {
144 Self::new(1000, 0.1)
145 }
146}
147
148pub struct NoveltySelection {
151 pub k: usize,
153 pub tournament_size: usize,
155 archive: Arc<RwLock<NoveltyArchive>>,
157}
158
159impl NoveltySelection {
160 pub fn new(k: usize, tournament_size: usize, archive: Arc<RwLock<NoveltyArchive>>) -> Self {
162 Self {
163 k,
164 tournament_size,
165 archive,
166 }
167 }
168
169 pub fn with_archive(archive: Arc<RwLock<NoveltyArchive>>) -> Self {
171 Self::new(15, 5, archive)
172 }
173
174 pub fn compute_novelty_scores<G: Genome>(&self, population: &Population<G>) -> Vec<f64> {
176 let archive = self.archive.read().unwrap();
177
178 let pop_behaviors: Vec<&BehaviorDescriptor> = population
180 .individuals
181 .iter()
182 .filter_map(|ind| ind.behavior.as_ref())
183 .collect();
184
185 population
187 .individuals
188 .iter()
189 .map(|ind| {
190 ind.behavior
191 .as_ref()
192 .map(|b| archive.compute_novelty(b, self.k, &pop_behaviors))
193 .unwrap_or(0.0)
194 })
195 .collect()
196 }
197
198 pub fn update_archive<G: Genome>(&self, population: &Population<G>) {
200 let novelty_scores = self.compute_novelty_scores(population);
201 let mut archive = self.archive.write().unwrap();
202
203 for (ind, novelty) in population.individuals.iter().zip(novelty_scores.iter()) {
204 if let Some(behavior) = &ind.behavior {
205 archive.add(behavior, *novelty);
206 }
207 }
208 }
209
210 pub fn archive(&self) -> Arc<RwLock<NoveltyArchive>> {
212 Arc::clone(&self.archive)
213 }
214}
215
216impl<G: Genome> Selection<G> for NoveltySelection {
217 fn select<'a, R: Rng>(
218 &self,
219 population: &'a Population<G>,
220 rng: &mut R,
221 ) -> &'a Individual<G> {
222 let novelty_scores = self.compute_novelty_scores(population);
223 let n = population.individuals.len();
224
225 let mut best_idx = 0;
226 let mut best_novelty = f64::NEG_INFINITY;
227
228 for _ in 0..self.tournament_size {
230 let idx = rng.gen_range(0..n);
231 let novelty = novelty_scores[idx];
232
233 if novelty > best_novelty {
234 best_novelty = novelty;
235 best_idx = idx;
236 }
237 }
238
239 &population.individuals[best_idx]
240 }
241}
242
243pub struct NoveltyFitnessSelection {
246 novelty: NoveltySelection,
248 fitness_weight: f64,
250}
251
252impl NoveltyFitnessSelection {
253 pub fn new(novelty: NoveltySelection, fitness_weight: f64) -> Self {
255 Self {
256 novelty,
257 fitness_weight: fitness_weight.clamp(0.0, 1.0),
258 }
259 }
260
261 pub fn compute_combined_scores<G: Genome>(&self, population: &Population<G>) -> Vec<f64> {
263 let novelty_scores = self.novelty.compute_novelty_scores(population);
264
265 let max_novelty = novelty_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
267 let min_novelty = novelty_scores.iter().cloned().fold(f64::INFINITY, f64::min);
268 let novelty_range = max_novelty - min_novelty;
269
270 population
271 .individuals
272 .iter()
273 .zip(novelty_scores.iter())
274 .map(|(ind, &novelty)| {
275 let fitness = ind.fitness_value();
276 let norm_novelty = if novelty_range > 0.0 {
277 (novelty - min_novelty) / novelty_range
278 } else {
279 0.5
280 };
281
282 self.fitness_weight * fitness + (1.0 - self.fitness_weight) * norm_novelty
283 })
284 .collect()
285 }
286
287 pub fn archive(&self) -> Arc<RwLock<NoveltyArchive>> {
289 self.novelty.archive()
290 }
291
292 pub fn update_archive<G: Genome>(&self, population: &Population<G>) {
294 self.novelty.update_archive(population);
295 }
296}
297
298impl<G: Genome> Selection<G> for NoveltyFitnessSelection {
299 fn select<'a, R: Rng>(
300 &self,
301 population: &'a Population<G>,
302 rng: &mut R,
303 ) -> &'a Individual<G> {
304 let combined_scores = self.compute_combined_scores(population);
305 let n = population.individuals.len();
306
307 let mut best_idx = 0;
308 let mut best_score = f64::NEG_INFINITY;
309
310 for _ in 0..self.novelty.tournament_size {
312 let idx = rng.gen_range(0..n);
313 let score = combined_scores[idx];
314
315 if score > best_score {
316 best_score = score;
317 best_idx = idx;
318 }
319 }
320
321 &population.individuals[best_idx]
322 }
323}
324
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fitness::FitnessValue;
    use crate::population::PopulationConfig;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;
    use std::collections::HashSet;

    /// Minimal genome whose phenotype is a single `f64` in `[0, 1)`.
    #[derive(Clone)]
    struct TestGenome {
        value: f64,
    }

    impl Genome for TestGenome {
        type Phenotype = f64;

        fn random<R: Rng>(rng: &mut R) -> Self {
            Self {
                value: rng.gen_range(0.0..1.0),
            }
        }

        fn mutate<R: Rng>(&mut self, rng: &mut R, _rate: f64) {
            self.value = rng.gen_range(0.0..1.0);
        }

        fn crossover<R: Rng>(&self, other: &Self, _rng: &mut R) -> Self {
            Self {
                value: (self.value + other.value) / 2.0,
            }
        }

        fn to_phenotype(&self) -> f64 {
            self.value
        }
    }

    /// Shared setup: an RNG seeded with 42 plus a random population drawn
    /// from it. The RNG is returned so tests continue the same sequence.
    fn seeded_population(config: PopulationConfig) -> (ChaCha8Rng, Population<TestGenome>) {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let pop = Population::random(config, &mut rng);
        (rng, pop)
    }

    #[test]
    fn test_tournament_new() {
        let tournament = Tournament::new(5);
        assert_eq!(tournament.size, 5);
    }

    #[test]
    fn test_tournament_selects_from_population() {
        let (mut rng, mut pop) = seeded_population(PopulationConfig {
            size: 10,
            elitism: 1,
        });
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(i as f64 / 10.0));
        }

        let tournament = Tournament::new(3);
        let selected = tournament.select(&pop, &mut rng);

        assert!(selected.fitness.is_some());
    }

    #[test]
    fn test_tournament_prefers_higher_fitness() {
        let (mut rng, mut pop) = seeded_population(PopulationConfig {
            size: 10,
            elitism: 1,
        });
        // Exactly one standout individual.
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            let fitness = if i == 5 { 100.0 } else { 0.0 };
            ind.fitness = Some(FitnessValue::Single(fitness));
        }

        let tournament = Tournament::new(5);
        let high_fitness_count = (0..100)
            .filter(|_| tournament.select(&pop, &mut rng).fitness_value() > 50.0)
            .count();

        // A size-5 tournament over 10 individuals includes the standout with
        // probability 1 - 0.9^5 ≈ 0.41.
        assert!(high_fitness_count > 30, "Expected >30, got {}", high_fitness_count);
    }

    #[test]
    fn test_tournament_size_one_is_random() {
        let (mut rng, mut pop) = seeded_population(PopulationConfig {
            size: 10,
            elitism: 1,
        });
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(i as f64));
        }

        let tournament = Tournament::new(1);
        let distinct: HashSet<i32> = (0..1000)
            .map(|_| tournament.select(&pop, &mut rng).fitness_value() as i32)
            .collect();

        assert!(distinct.len() > 1);
    }

    #[test]
    fn test_novelty_archive_new() {
        let archive = NoveltyArchive::new(100, 0.5);
        assert!(archive.is_empty());
        assert_eq!(archive.len(), 0);
    }

    #[test]
    fn test_novelty_archive_add() {
        let mut archive = NoveltyArchive::new(100, 0.5);

        let behavior = BehaviorDescriptor::new(vec![1.0, 2.0, 3.0]);
        assert!(archive.add(&behavior, 0.6));
        assert_eq!(archive.len(), 1);

        // Below the 0.5 threshold: rejected.
        let behavior2 = BehaviorDescriptor::new(vec![4.0, 5.0, 6.0]);
        assert!(!archive.add(&behavior2, 0.3));
        assert_eq!(archive.len(), 1);
    }

    #[test]
    fn test_novelty_archive_force_add() {
        let mut archive = NoveltyArchive::new(100, 0.5);
        archive.force_add(BehaviorDescriptor::new(vec![1.0, 2.0, 3.0]));
        assert_eq!(archive.len(), 1);
    }

    #[test]
    fn test_novelty_archive_compute_novelty() {
        let mut archive = NoveltyArchive::new(100, 0.0);
        archive.force_add(BehaviorDescriptor::new(vec![0.0, 0.0]));
        archive.force_add(BehaviorDescriptor::new(vec![1.0, 0.0]));
        archive.force_add(BehaviorDescriptor::new(vec![0.0, 1.0]));

        let close_behavior = BehaviorDescriptor::new(vec![0.1, 0.1]);
        let novelty = archive.compute_novelty(&close_behavior, 2, &[]);
        assert!(novelty < 1.0, "Close point should have low novelty");

        let far_behavior = BehaviorDescriptor::new(vec![10.0, 10.0]);
        let far_novelty = archive.compute_novelty(&far_behavior, 2, &[]);
        assert!(far_novelty > novelty, "Far point should have higher novelty");
    }

    #[test]
    fn test_novelty_selection_new() {
        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::new(15, 5, archive);
        assert_eq!(selection.k, 15);
        assert_eq!(selection.tournament_size, 5);
    }

    #[test]
    fn test_novelty_selection_with_archive() {
        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::with_archive(archive);
        assert_eq!(selection.k, 15);
        assert_eq!(selection.tournament_size, 5);
    }

    #[test]
    fn test_novelty_selection_select() {
        let (mut rng, mut pop) = seeded_population(PopulationConfig {
            size: 10,
            elitism: 1,
        });
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(0.5));
            let coord = if i == 5 { 100.0 } else { i as f64 * 0.1 };
            ind.behavior = Some(BehaviorDescriptor::new(vec![coord, coord]));
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::new(3, 5, archive);

        let selected = selection.select(&pop, &mut rng);
        assert!(selected.behavior.is_some());
    }

    #[test]
    fn test_novelty_selection_prefers_novel() {
        let (mut rng, mut pop) = seeded_population(PopulationConfig {
            size: 10,
            elitism: 1,
        });
        // Individual 5 sits far from a tight cluster formed by everyone else.
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(0.5));
            let coord = if i == 5 { 100.0 } else { i as f64 * 0.01 };
            ind.behavior = Some(BehaviorDescriptor::new(vec![coord, coord]));
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::new(3, 8, archive);

        let novel_count = (0..100)
            .filter(|_| {
                selection
                    .select(&pop, &mut rng)
                    .behavior
                    .as_ref()
                    .map_or(false, |behavior| behavior.values[0] > 50.0)
            })
            .count();

        assert!(novel_count > 30, "Expected >30, got {}", novel_count);
    }

    #[test]
    fn test_novelty_fitness_selection() {
        let (mut rng, mut pop) = seeded_population(PopulationConfig {
            size: 10,
            elitism: 1,
        });
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(i as f64 / 10.0));
            ind.behavior = Some(BehaviorDescriptor::new(vec![i as f64, i as f64]));
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let novelty = NoveltySelection::new(3, 5, archive);
        let selection = NoveltyFitnessSelection::new(novelty, 0.5);

        let selected = selection.select(&pop, &mut rng);
        assert!(selected.fitness.is_some());
    }

    #[test]
    fn test_novelty_archive_update() {
        let (_rng, mut pop) = seeded_population(PopulationConfig {
            size: 5,
            elitism: 1,
        });
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(0.5));
            ind.behavior = Some(BehaviorDescriptor::new(vec![i as f64 * 10.0, i as f64 * 10.0]));
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::new(100, 0.0)));
        let selection = NoveltySelection::new(3, 5, Arc::clone(&archive));

        selection.update_archive(&pop);

        let archive_read = archive.read().unwrap();
        assert!(!archive_read.is_empty());
    }
}