Skip to main content

red_queen_core/
selection.rs

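//! Parent-selection operators: fitness tournament ([`Tournament`]), novelty tournament
//! ([`NoveltySelection`]) and a blended fitness/novelty tournament
//! ([`NoveltyFitnessSelection`]), plus the [`NoveltyArchive`] used by the novelty-based selectors.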
2
3use crate::genome::{BehaviorDescriptor, Genome};
4use crate::population::{Individual, Population};
5use rand::Rng;
6use std::sync::{Arc, RwLock};
7
8/// Trait for selection operators.
9pub trait Selection<G: Genome>: Send + Sync {
10    /// Select one parent from the population.
11    fn select<'a, R: Rng>(
12        &self,
13        population: &'a Population<G>,
14        rng: &mut R,
15    ) -> &'a Individual<G>;
16}
17
/// Tournament selection: draw `size` individuals uniformly at random (with
/// replacement) and keep the one with the highest fitness.
///
/// Larger tournaments make it more likely that the fittest individuals win,
/// i.e. they increase selection pressure.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Tournament {
    /// Number of contestants drawn per tournament.
    pub size: usize,
}

impl Tournament {
    /// Create a tournament selector that draws `size` contestants per selection.
    #[must_use]
    pub const fn new(size: usize) -> Self {
        Self { size }
    }
}
30
31impl<G: Genome> Selection<G> for Tournament {
32    fn select<'a, R: Rng>(
33        &self,
34        population: &'a Population<G>,
35        rng: &mut R,
36    ) -> &'a Individual<G> {
37        let n = population.individuals.len();
38        let mut best: Option<&Individual<G>> = None;
39        let mut best_fitness = f64::NEG_INFINITY;
40
41        for _ in 0..self.size {
42            let idx = rng.gen_range(0..n);
43            let ind = &population.individuals[idx];
44            let fitness = ind.fitness_value();
45
46            if fitness > best_fitness {
47                best_fitness = fitness;
48                best = Some(ind);
49            }
50        }
51
52        best.unwrap_or(&population.individuals[0])
53    }
54}
55
56/// Archive for tracking explored behaviors (for novelty computation).
57#[derive(Clone)]
58pub struct NoveltyArchive {
59    /// Stored behavior descriptors.
60    behaviors: Vec<BehaviorDescriptor>,
61    /// Maximum archive size.
62    max_size: usize,
63    /// Minimum novelty threshold for adding to archive.
64    add_threshold: f64,
65}
66
67impl NoveltyArchive {
68    /// Create a new novelty archive.
69    pub fn new(max_size: usize, add_threshold: f64) -> Self {
70        Self {
71            behaviors: Vec::new(),
72            max_size,
73            add_threshold,
74        }
75    }
76
77    /// Add a behavior to the archive if it's novel enough.
78    pub fn add(&mut self, behavior: &BehaviorDescriptor, novelty: f64) -> bool {
79        if novelty >= self.add_threshold && self.behaviors.len() < self.max_size {
80            self.behaviors.push(behavior.clone());
81            true
82        } else {
83            false
84        }
85    }
86
87    /// Force add a behavior (ignoring threshold).
88    pub fn force_add(&mut self, behavior: BehaviorDescriptor) {
89        if self.behaviors.len() < self.max_size {
90            self.behaviors.push(behavior);
91        }
92    }
93
94    /// Get all behaviors in the archive.
95    pub fn behaviors(&self) -> &[BehaviorDescriptor] {
96        &self.behaviors
97    }
98
99    /// Number of behaviors in archive.
100    pub fn len(&self) -> usize {
101        self.behaviors.len()
102    }
103
104    /// Check if archive is empty.
105    pub fn is_empty(&self) -> bool {
106        self.behaviors.is_empty()
107    }
108
109    /// Compute novelty score for a behavior.
110    /// Novelty = average distance to k nearest neighbors.
111    pub fn compute_novelty(&self, behavior: &BehaviorDescriptor, k: usize, population_behaviors: &[&BehaviorDescriptor]) -> f64 {
112        // Combine archive behaviors with current population
113        let all_behaviors: Vec<&BehaviorDescriptor> = self.behaviors
114            .iter()
115            .chain(population_behaviors.iter().copied())
116            .collect();
117
118        if all_behaviors.is_empty() {
119            return f64::MAX; // Maximum novelty if no comparisons
120        }
121
122        // Compute distances to all other behaviors
123        let mut distances: Vec<f64> = all_behaviors
124            .iter()
125            .map(|other| behavior.distance(other))
126            .filter(|d| *d > 0.0) // Exclude self
127            .collect();
128
129        if distances.is_empty() {
130            return 0.0;
131        }
132
133        // Sort and take k nearest
134        distances.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
135        let k = k.min(distances.len());
136
137        // Average distance to k nearest neighbors
138        distances.iter().take(k).sum::<f64>() / k as f64
139    }
140}
141
142impl Default for NoveltyArchive {
143    fn default() -> Self {
144        Self::new(1000, 0.1)
145    }
146}
147
148/// Novelty-based selection.
149/// Selects individuals based on how different they are from others.
150pub struct NoveltySelection {
151    /// Number of nearest neighbors for novelty computation.
152    pub k: usize,
153    /// Tournament size for selection.
154    pub tournament_size: usize,
155    /// Shared novelty archive.
156    archive: Arc<RwLock<NoveltyArchive>>,
157}
158
159impl NoveltySelection {
160    /// Create a new novelty selector.
161    pub fn new(k: usize, tournament_size: usize, archive: Arc<RwLock<NoveltyArchive>>) -> Self {
162        Self {
163            k,
164            tournament_size,
165            archive,
166        }
167    }
168
169    /// Create with default parameters (k=15, tournament=5).
170    pub fn with_archive(archive: Arc<RwLock<NoveltyArchive>>) -> Self {
171        Self::new(15, 5, archive)
172    }
173
174    /// Compute novelty scores for all individuals in population.
175    pub fn compute_novelty_scores<G: Genome>(&self, population: &Population<G>) -> Vec<f64> {
176        let archive = self.archive.read().unwrap();
177
178        // Collect behaviors from population
179        let pop_behaviors: Vec<&BehaviorDescriptor> = population
180            .individuals
181            .iter()
182            .filter_map(|ind| ind.behavior.as_ref())
183            .collect();
184
185        // Compute novelty for each individual
186        population
187            .individuals
188            .iter()
189            .map(|ind| {
190                ind.behavior
191                    .as_ref()
192                    .map(|b| archive.compute_novelty(b, self.k, &pop_behaviors))
193                    .unwrap_or(0.0)
194            })
195            .collect()
196    }
197
198    /// Update the archive with novel individuals from the population.
199    pub fn update_archive<G: Genome>(&self, population: &Population<G>) {
200        let novelty_scores = self.compute_novelty_scores(population);
201        let mut archive = self.archive.write().unwrap();
202
203        for (ind, novelty) in population.individuals.iter().zip(novelty_scores.iter()) {
204            if let Some(behavior) = &ind.behavior {
205                archive.add(behavior, *novelty);
206            }
207        }
208    }
209
210    /// Get a reference to the archive.
211    pub fn archive(&self) -> Arc<RwLock<NoveltyArchive>> {
212        Arc::clone(&self.archive)
213    }
214}
215
216impl<G: Genome> Selection<G> for NoveltySelection {
217    fn select<'a, R: Rng>(
218        &self,
219        population: &'a Population<G>,
220        rng: &mut R,
221    ) -> &'a Individual<G> {
222        let novelty_scores = self.compute_novelty_scores(population);
223        let n = population.individuals.len();
224
225        let mut best_idx = 0;
226        let mut best_novelty = f64::NEG_INFINITY;
227
228        // Tournament selection based on novelty
229        for _ in 0..self.tournament_size {
230            let idx = rng.gen_range(0..n);
231            let novelty = novelty_scores[idx];
232
233            if novelty > best_novelty {
234                best_novelty = novelty;
235                best_idx = idx;
236            }
237        }
238
239        &population.individuals[best_idx]
240    }
241}
242
243/// Combined fitness and novelty selection.
244/// Uses a weighted combination of fitness and novelty scores.
245pub struct NoveltyFitnessSelection {
246    /// Novelty selector.
247    novelty: NoveltySelection,
248    /// Weight for fitness (0.0 = pure novelty, 1.0 = pure fitness).
249    fitness_weight: f64,
250}
251
252impl NoveltyFitnessSelection {
253    /// Create a new combined selector.
254    pub fn new(novelty: NoveltySelection, fitness_weight: f64) -> Self {
255        Self {
256            novelty,
257            fitness_weight: fitness_weight.clamp(0.0, 1.0),
258        }
259    }
260
261    /// Compute combined scores (fitness + novelty).
262    pub fn compute_combined_scores<G: Genome>(&self, population: &Population<G>) -> Vec<f64> {
263        let novelty_scores = self.novelty.compute_novelty_scores(population);
264
265        // Normalize novelty scores
266        let max_novelty = novelty_scores.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
267        let min_novelty = novelty_scores.iter().cloned().fold(f64::INFINITY, f64::min);
268        let novelty_range = max_novelty - min_novelty;
269
270        population
271            .individuals
272            .iter()
273            .zip(novelty_scores.iter())
274            .map(|(ind, &novelty)| {
275                let fitness = ind.fitness_value();
276                let norm_novelty = if novelty_range > 0.0 {
277                    (novelty - min_novelty) / novelty_range
278                } else {
279                    0.5
280                };
281
282                self.fitness_weight * fitness + (1.0 - self.fitness_weight) * norm_novelty
283            })
284            .collect()
285    }
286
287    /// Get a reference to the novelty archive.
288    pub fn archive(&self) -> Arc<RwLock<NoveltyArchive>> {
289        self.novelty.archive()
290    }
291
292    /// Update the archive with novel individuals.
293    pub fn update_archive<G: Genome>(&self, population: &Population<G>) {
294        self.novelty.update_archive(population);
295    }
296}
297
298impl<G: Genome> Selection<G> for NoveltyFitnessSelection {
299    fn select<'a, R: Rng>(
300        &self,
301        population: &'a Population<G>,
302        rng: &mut R,
303    ) -> &'a Individual<G> {
304        let combined_scores = self.compute_combined_scores(population);
305        let n = population.individuals.len();
306
307        let mut best_idx = 0;
308        let mut best_score = f64::NEG_INFINITY;
309
310        // Tournament selection based on combined score
311        for _ in 0..self.novelty.tournament_size {
312            let idx = rng.gen_range(0..n);
313            let score = combined_scores[idx];
314
315            if score > best_score {
316                best_score = score;
317                best_idx = idx;
318            }
319        }
320
321        &population.individuals[best_idx]
322    }
323}
324
// Unit tests for the selection operators. Every test seeds `ChaCha8Rng` with 42, so
// the statistical assertions (counts out of 100 draws) are deterministic for a given
// sequence of RNG calls.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fitness::FitnessValue;
    use crate::population::PopulationConfig;
    use rand::SeedableRng;
    use rand_chacha::ChaCha8Rng;

    // Simple test genome: a single value in [0, 1). The selection tests overwrite
    // fitness/behaviour directly, so the genome only needs to satisfy `Genome`.
    #[derive(Clone)]
    struct TestGenome {
        value: f64,
    }

    impl Genome for TestGenome {
        type Phenotype = f64;

        fn random<R: Rng>(rng: &mut R) -> Self {
            Self {
                value: rng.gen_range(0.0..1.0),
            }
        }

        // Ignores the rate and resamples the value uniformly.
        fn mutate<R: Rng>(&mut self, rng: &mut R, _rate: f64) {
            self.value = rng.gen_range(0.0..1.0);
        }

        // Child is the midpoint of the two parents.
        fn crossover<R: Rng>(&self, other: &Self, _rng: &mut R) -> Self {
            Self {
                value: (self.value + other.value) / 2.0,
            }
        }

        fn to_phenotype(&self) -> f64 {
            self.value
        }
    }

    #[test]
    fn test_tournament_new() {
        let tournament = Tournament::new(5);
        assert_eq!(tournament.size, 5);
    }

    #[test]
    fn test_tournament_selects_from_population() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 10,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Assign fitness values
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(i as f64 / 10.0));
        }

        let tournament = Tournament::new(3);
        let selected = tournament.select(&pop, &mut rng);

        // Every individual was given a fitness above, so the winner must carry one.
        assert!(selected.fitness.is_some());
    }

    #[test]
    fn test_tournament_prefers_higher_fitness() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 10,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Give one individual very high fitness
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            if i == 5 {
                ind.fitness = Some(FitnessValue::Single(100.0));
            } else {
                ind.fitness = Some(FitnessValue::Single(0.0));
            }
        }

        // With large tournament, should frequently select the best
        let tournament = Tournament::new(5); // Half population
        let mut high_fitness_count = 0;
        for _ in 0..100 {
            let selected = tournament.select(&pop, &mut rng);
            if selected.fitness_value() > 50.0 {
                high_fitness_count += 1;
            }
        }

        // With tournament size 5 out of 10, probability of including the best
        // in each tournament is 1 - (9/10)^5 ≈ 0.41, so we expect ~40+ out of 100
        // (whenever the best is drawn it wins, since every other fitness is 0).
        assert!(high_fitness_count > 30, "Expected >30, got {}", high_fitness_count);
    }

    #[test]
    fn test_tournament_size_one_is_random() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 10,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Distinct integer fitnesses let us identify the winner by its fitness.
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(i as f64));
        }

        // Tournament size 1 = random selection
        let tournament = Tournament::new(1);
        let mut selections = std::collections::HashMap::new();

        for _ in 0..1000 {
            let selected = tournament.select(&pop, &mut rng);
            let fitness = selected.fitness_value() as i32;
            *selections.entry(fitness).or_insert(0) += 1;
        }

        // Should have selected multiple different individuals
        assert!(selections.len() > 1);
    }

    // Novelty selection tests

    #[test]
    fn test_novelty_archive_new() {
        let archive = NoveltyArchive::new(100, 0.5);
        assert!(archive.is_empty());
        assert_eq!(archive.len(), 0);
    }

    #[test]
    fn test_novelty_archive_add() {
        let mut archive = NoveltyArchive::new(100, 0.5);
        let behavior = BehaviorDescriptor::new(vec![1.0, 2.0, 3.0]);

        // Should add if novelty >= threshold
        assert!(archive.add(&behavior, 0.6));
        assert_eq!(archive.len(), 1);

        // Should not add if novelty < threshold
        let behavior2 = BehaviorDescriptor::new(vec![4.0, 5.0, 6.0]);
        assert!(!archive.add(&behavior2, 0.3));
        assert_eq!(archive.len(), 1);
    }

    #[test]
    fn test_novelty_archive_force_add() {
        let mut archive = NoveltyArchive::new(100, 0.5);
        let behavior = BehaviorDescriptor::new(vec![1.0, 2.0, 3.0]);

        // force_add bypasses the 0.5 threshold entirely.
        archive.force_add(behavior);
        assert_eq!(archive.len(), 1);
    }

    #[test]
    fn test_novelty_archive_compute_novelty() {
        let mut archive = NoveltyArchive::new(100, 0.0);

        // Add some behaviors
        archive.force_add(BehaviorDescriptor::new(vec![0.0, 0.0]));
        archive.force_add(BehaviorDescriptor::new(vec![1.0, 0.0]));
        archive.force_add(BehaviorDescriptor::new(vec![0.0, 1.0]));

        // Test novelty of a point close to origin (empty population slice: archive only)
        let close_behavior = BehaviorDescriptor::new(vec![0.1, 0.1]);
        let novelty = archive.compute_novelty(&close_behavior, 2, &[]);
        assert!(novelty < 1.0, "Close point should have low novelty");

        // Test novelty of a point far from all others
        let far_behavior = BehaviorDescriptor::new(vec![10.0, 10.0]);
        let far_novelty = archive.compute_novelty(&far_behavior, 2, &[]);
        assert!(far_novelty > novelty, "Far point should have higher novelty");
    }

    #[test]
    fn test_novelty_selection_new() {
        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::new(15, 5, archive);
        assert_eq!(selection.k, 15);
        assert_eq!(selection.tournament_size, 5);
    }

    #[test]
    fn test_novelty_selection_with_archive() {
        // with_archive should apply the documented defaults (k = 15, tournament = 5).
        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::with_archive(archive);
        assert_eq!(selection.k, 15);
        assert_eq!(selection.tournament_size, 5);
    }

    #[test]
    fn test_novelty_selection_select() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 10,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Assign behaviors - make one very different
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(0.5));
            if i == 5 {
                ind.behavior = Some(BehaviorDescriptor::new(vec![100.0, 100.0]));
            } else {
                ind.behavior = Some(BehaviorDescriptor::new(vec![i as f64 * 0.1, i as f64 * 0.1]));
            }
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::new(3, 5, archive);

        // Selection should work without panicking
        let selected = selection.select(&pop, &mut rng);
        assert!(selected.behavior.is_some());
    }

    #[test]
    fn test_novelty_selection_prefers_novel() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 10,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Individuals clustered around origin with slight variation, except one outlier
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(0.5));
            if i == 5 {
                // Far outlier
                ind.behavior = Some(BehaviorDescriptor::new(vec![100.0, 100.0]));
            } else {
                // Slight variation around origin
                ind.behavior = Some(BehaviorDescriptor::new(vec![i as f64 * 0.01, i as f64 * 0.01]));
            }
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let selection = NoveltySelection::new(3, 8, archive);  // Large tournament

        // Count how often the novel individual is selected
        let mut novel_count = 0;
        for _ in 0..100 {
            let selected = selection.select(&pop, &mut rng);
            if let Some(behavior) = &selected.behavior {
                if behavior.values[0] > 50.0 {
                    novel_count += 1;
                }
            }
        }

        // Should frequently select the most novel individual
        // With tournament size 8 out of 10 (drawn with replacement), the outlier appears
        // in a tournament with probability 1 - (9/10)^8 ≈ 0.57 and then always wins,
        // so we expect roughly 57 of 100.
        assert!(novel_count > 30, "Expected >30, got {}", novel_count);
    }

    #[test]
    fn test_novelty_fitness_selection() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 10,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Assign fitness and behaviors
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(i as f64 / 10.0));
            ind.behavior = Some(BehaviorDescriptor::new(vec![i as f64, i as f64]));
        }

        let archive = Arc::new(RwLock::new(NoveltyArchive::default()));
        let novelty = NoveltySelection::new(3, 5, archive);
        let selection = NoveltyFitnessSelection::new(novelty, 0.5);

        // Selection should work
        let selected = selection.select(&pop, &mut rng);
        assert!(selected.fitness.is_some());
    }

    #[test]
    fn test_novelty_archive_update() {
        let mut rng = ChaCha8Rng::seed_from_u64(42);
        let config = PopulationConfig {
            size: 5,
            elitism: 1,
        };
        let mut pop: Population<TestGenome> = Population::random(config, &mut rng);

        // Assign behaviors
        for (i, ind) in pop.individuals.iter_mut().enumerate() {
            ind.fitness = Some(FitnessValue::Single(0.5));
            ind.behavior = Some(BehaviorDescriptor::new(vec![i as f64 * 10.0, i as f64 * 10.0]));
        }

        // Threshold 0.0 admits every scored behaviour; keep a clone of the Arc to inspect it.
        let archive = Arc::new(RwLock::new(NoveltyArchive::new(100, 0.0)));
        let selection = NoveltySelection::new(3, 5, archive.clone());

        selection.update_archive(&pop);

        // Archive should have some behaviors now
        let archive_read = archive.read().unwrap();
        assert!(archive_read.len() > 0);
    }
}