1use super::*;
2use rand::distributions::Uniform;
3use rayon::prelude::*;
4
/// Declarative description of a parent-selection scheme.
///
/// Convert a variant into a working selector with [`SelectionStrategy::create_selector`].
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum SelectionStrategy {
    /// Tournament selection; the payload is the tournament size
    /// (see [`TournamentSelection`]).
    Tournament(usize),

    /// Fitness-proportionate selection (see [`RouletteWheelSelection`]).
    RouletteWheel,

    /// Boltzmann selection; the payload is `(temperature, cooling_rate)`
    /// (see [`BoltzmannSelection`]).
    Boltzmann(f64, f64),

    /// Rank-proportionate selection (see [`RankSelection`]).
    Rank,

    /// Linear-bias selection; the payload is the bias (see [`LinearSelection`]).
    Linear(f64),

    /// Uniform choice among indices at or above `floor(len * ratio)`; the payload is that
    /// ratio (see [`ElitistSelection`]).
    Elitist(f64),
}
43
44impl SelectionStrategy {
45    pub fn create_selector<I: Individual, R: Rng>(&self) -> Box<dyn SelectionMechanism<I, R>> {
47        match self {
48            SelectionStrategy::Tournament(size) => Box::new(TournamentSelection::new(*size)),
49            SelectionStrategy::RouletteWheel => Box::new(RouletteWheelSelection::new()),
50            SelectionStrategy::Rank => Box::new(RankSelection::new()),
51            SelectionStrategy::Linear(bias) => Box::new(LinearSelection::new(*bias)),
52            SelectionStrategy::Elitist(ratio) => Box::new(ElitistSelection::new(*ratio)),
53            SelectionStrategy::Boltzmann(temperature, cooling_rate) => {
54                Box::new(BoltzmannSelection::new(*temperature, *cooling_rate))
55            }
56        }
57    }
58}
59
60pub trait SelectionMechanism<I, R>: Sync
63where
64    I: Individual,
65    R: Rng,
66{
67    fn prepare(&mut self, population: &Population<I>);
77
78    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I);
90
91    #[inline]
112    fn select_distinct<'a>(
113        &self,
114        population: &'a Population<I>,
115        rng: &mut R,
116        max_tries: usize,
117        excluded: &I,
118    ) -> Option<(usize, &'a I)>
119    where
120        I: Individual,
121        R: Rng,
122    {
123        for _ in 0..max_tries {
124            let (index, candidate) = self.select(population, rng);
125            if candidate != excluded {
126                return Some((index, candidate));
127            }
128        }
129        None
130    }
131}
132
/// Tournament selection.
///
/// Draws `size` contestant indices uniformly with replacement and keeps the highest one.
#[derive(Clone, Copy, PartialEq, Eq)]
pub struct TournamentSelection {
    /// Contestants per tournament; never zero.
    size: usize,
}

impl TournamentSelection {
    /// Creates a selector that runs tournaments of `size` contestants.
    ///
    /// # Panics
    /// Panics if `size` is zero.
    pub fn new(size: usize) -> Self {
        assert!(size != 0);
        Self { size }
    }

    /// Number of contestants per tournament.
    pub fn size(&self) -> usize {
        self.size
    }

    /// Changes the number of contestants per tournament.
    ///
    /// # Panics
    /// Panics if `size` is zero.
    pub fn set_size(&mut self, size: usize) {
        // `size` is the only field, so rebuilding through `new` reuses its validation.
        *self = Self::new(size);
    }
}
168
169impl<I, R> SelectionMechanism<I, R> for TournamentSelection
170where
171    I: Individual,
172    R: Rng,
173{
174    #[inline]
175    fn prepare(&mut self, _population: &Population<I>) {}
176
177    #[inline]
178    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I) {
179        rng.sample_iter(Uniform::new(0, population.len()))
180            .take(self.size)
181            .max() .map(|i| (i, &population[i]))
183            .unwrap()
184    }
185}
186
/// Fitness-proportionate ("roulette wheel") selection.
///
/// Each individual is chosen with probability `fitness / total_fitness`. This is only
/// meaningful when all fitness values are non-negative.
#[derive(Default, Clone, Copy, PartialEq)]
pub struct RouletteWheelSelection {
    /// Sum of the population's fitness values, cached by `prepare`.
    total_fitness: f64,
}

impl RouletteWheelSelection {
    /// Creates a selector with an empty cache; `prepare` must run before selecting.
    pub fn new() -> Self {
        Self::default()
    }
}
205
206impl<I, R> SelectionMechanism<I, R> for RouletteWheelSelection
207where
208    I: Individual,
209    R: Rng,
210{
211    #[inline]
212    fn prepare(&mut self, population: &Population<I>) {
213        self.total_fitness = population.par_iter().map(|i| i.fitness()).sum();
214    }
215
216    #[inline]
217    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I) {
218        let target_fitness = rng.gen::<f64>() * self.total_fitness;
219        let mut cumulative_fitness = 0.0;
220        for (index, individual) in population.iter().enumerate() {
221            cumulative_fitness += individual.fitness();
222            if cumulative_fitness >= target_fitness {
223                return (index, individual);
224            }
225        }
226        (population.len() - 1, population.last().unwrap())
227    }
228}
229
230#[derive(Clone, Copy, PartialEq)]
241pub struct BoltzmannSelection {
242    temperature: f64,
243    cooling_rate: f64,
244    total_scaled_fitness: f64,
245}
246
247impl BoltzmannSelection {
248    pub fn new(temperature: f64, cooling_rate: f64) -> Self {
253        assert!(cooling_rate != 0.0);
254        Self {
255            temperature,
256            cooling_rate,
257            total_scaled_fitness: 0.0,
258        }
259    }
260
261    pub fn temperature(&self) -> f64 {
263        self.temperature
264    }
265
266    pub fn set_temperature(&mut self, temperature: f64) {
268        self.temperature = temperature
269    }
270
271    pub fn cooling_rate(&self) -> f64 {
273        self.cooling_rate
274    }
275
276    pub fn set_cooling_rate(&mut self, cooling_rate: f64) {
281        assert!(cooling_rate != 0.0);
282        self.cooling_rate = cooling_rate
283    }
284
285    #[inline]
287    fn scaled_fitness<I: Individual>(&self, individual: &I) -> f64 {
288        (individual.fitness() / self.temperature).exp()
289    }
290}
291
292impl<I, R> SelectionMechanism<I, R> for BoltzmannSelection
293where
294    I: Individual,
295    R: Rng,
296{
297    #[inline]
298    fn prepare(&mut self, population: &Population<I>) {
299        self.temperature *= self.cooling_rate;
300        self.total_scaled_fitness = population
301            .par_iter()
302            .map(|i| self.scaled_fitness(i))
303            .sum();
304    }
305
306    #[inline]
307    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I) {
308        let target_fitness = rng.gen::<f64>() * self.total_scaled_fitness;
309        let mut cumulative_fitness = 0.0;
310        for (index, individual) in population.iter().enumerate() {
311            cumulative_fitness += self.scaled_fitness(individual);
312            if cumulative_fitness >= target_fitness {
313                return (index, individual);
314            }
315        }
316        (population.len() - 1, population.last().unwrap())
317    }
318}
319
/// Rank-proportionate selection: the individual at index `i` carries weight `i + 1`.
///
/// The selector is stateless.
#[derive(Default, Clone, Copy, PartialEq, Eq)]
pub struct RankSelection {}

impl RankSelection {
    /// Creates a rank selector; equivalent to [`Default::default`].
    pub fn new() -> Self {
        Self::default()
    }
}
336
337impl<I, R> SelectionMechanism<I, R> for RankSelection
338where
339    I: Individual,
340    R: Rng,
341{
342    #[inline]
343    fn prepare(&mut self, _population: &Population<I>) {
344        }
346
347    #[inline]
348    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I) {
349        let n = population.len();
350        let total_ranks = n * (n + 1) / 2;
351        let target_rank = rng.gen_range(1..=total_ranks);
352
353        let mut cumulative_rank = 0;
354        for (index, individual) in population.iter().enumerate() {
355            cumulative_rank += index + 1;
356            if cumulative_rank >= target_rank {
357                return (index, individual);
358            }
359        }
360        (0, population.first().unwrap())
361    }
362}
363
/// Elitist selection.
///
/// Chooses uniformly among indices at or above `floor(len * ratio)`, ignoring the
/// lowest-indexed `ratio` fraction of the population.
///
/// Derives the same value traits as the other selectors in this module (plus `Debug`), so it
/// can be copied, compared and printed like them.
#[derive(Debug, Clone, Copy, PartialEq)]
pub struct ElitistSelection {
    /// Fraction of the population, in `[0, 1)`, skipped at the low-index end.
    ratio: f64,
}

impl ElitistSelection {
    /// Creates a selector that skips the lowest `ratio` fraction of the population.
    ///
    /// # Panics
    /// Panics unless `0.0 <= ratio < 1.0`; NaN is rejected as well.
    pub fn new(ratio: f64) -> Self {
        assert!((0.0..1.0).contains(&ratio));
        Self { ratio }
    }

    /// Fraction of the population skipped at the low-index end.
    pub fn ratio(&self) -> f64 {
        self.ratio
    }

    /// Changes the skipped fraction.
    ///
    /// # Panics
    /// Panics unless `0.0 <= ratio < 1.0`; NaN is rejected as well.
    pub fn set_ratio(&mut self, ratio: f64) {
        assert!((0.0..1.0).contains(&ratio));
        self.ratio = ratio
    }
}
407
408impl<I, R> SelectionMechanism<I, R> for ElitistSelection
409where
410    I: Individual,
411    R: Rng,
412{
413    #[inline]
414    fn prepare(&mut self, _population: &Population<I>) {
415        }
417
418    #[inline]
419    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I) {
420        let selection_bound = (population.len() as f64 * self.ratio).floor() as usize;
421        let index = rng.gen_range(selection_bound..population.len());
422        (index, &population[index])
423    }
424}
425
/// Linear-bias selection.
///
/// `bias` shapes a linear distribution over positions counted from the highest index. A bias
/// above 1 favours higher indices; a bias below 1 favours lower ones.
#[derive(Clone, Copy, PartialEq)]
pub struct LinearSelection {
    /// Shape parameter of the linear distribution.
    bias: f64,
}

impl LinearSelection {
    /// Creates a selector with the given `bias`.
    pub fn new(bias: f64) -> Self {
        Self { bias }
    }

    /// Current bias.
    pub fn bias(&self) -> f64 {
        self.bias
    }

    /// Replaces the bias.
    pub fn set_bias(&mut self, bias: f64) {
        *self = Self::new(bias)
    }
}
454
455impl<I, R> SelectionMechanism<I, R> for LinearSelection
456where
457    I: Individual,
458    R: Rng,
459{
460    #[inline]
461    fn prepare(&mut self, _population: &Population<I>) {
462        }
464
465    #[inline]
466    fn select<'a>(&self, population: &'a Population<I>, rng: &mut R) -> (usize, &'a I) {
467        let index = population.len()
468            - ((population.len() as f64)
469                * (self.bias
470                    - ((self.bias * self.bias - 4.0 * (self.bias - 1.0) * rng.gen::<f64>())
471                        .sqrt()))
472                / 2.0
473                / (self.bias - 1.0)) as usize
474            - 1;
475        (index, &population[index])
476    }
477}
478
#[cfg(test)]
mod tests {
    use super::*;

    /// Test individual whose fitness is just a stored number.
    #[derive(Debug, Clone, Copy, PartialEq)]
    struct MockIndividual {
        fitness: f64,
    }

    impl Individual for MockIndividual {
        fn fitness(&self) -> f64 {
            self.fitness
        }
    }

    const POPULATION_SIZE: usize = 10;
    const ITERATIONS: usize = 100000;
    /// Largest tolerated absolute gap between an observed and an expected probability.
    const CONFIDENCE: f64 = 0.01;

    /// Population whose fitness increases with the index: 0.0, 0.1, ..., 0.9.
    fn create_test_population() -> Population<MockIndividual> {
        (0..POPULATION_SIZE)
            .map(|i| MockIndividual {
                fitness: i as f64 / POPULATION_SIZE as f64,
            })
            .collect::<Vec<_>>()
            .try_into()
            .unwrap()
    }

    /// Prepares `strategy` once on the test population, then counts how often each index is
    /// selected over `ITERATIONS` draws from a fixed-seed RNG.
    fn collect_samples<S>(strategy: &mut S) -> (Vec<MockIndividual>, Vec<u32>)
    where
        S: SelectionMechanism<MockIndividual, SmallRng>,
    {
        let mut rng = SmallRng::seed_from_u64(0);
        let population = create_test_population();
        let mut counts = vec![0; POPULATION_SIZE];

        strategy.prepare(&population);
        for _ in 0..ITERATIONS {
            let (index, _) = strategy.select(&population, &mut rng);
            counts[index] += 1;
        }
        (population.into(), counts)
    }

    /// Asserts that the observed frequency of every index is within `CONFIDENCE` of
    /// `expected(index)`.
    fn assert_probabilities<F>(observed_counts: &[u32], expected: F)
    where
        F: Fn(usize) -> f64,
    {
        for (index, &count) in observed_counts.iter().enumerate() {
            let observed_probability = count as f64 / ITERATIONS as f64;
            let expected_probability = expected(index);
            assert!(
                (observed_probability - expected_probability).abs() < CONFIDENCE,
                "index {}: observed probability {}, expected {}",
                index,
                observed_probability,
                expected_probability
            );
        }
    }

    #[test]
    fn test_tournament_selection_probabilities() {
        let mut tournament = TournamentSelection::new(2);
        let (population, observed_counts) = collect_samples(&mut tournament);

        // The winner is the largest of `size` uniform draws, so
        // P(winner <= i) = ((i + 1) / n)^size.
        let n = population.len() as f64;
        let size = tournament.size as i32;
        assert_probabilities(&observed_counts, |index| {
            ((index as f64 + 1.0) / n).powi(size) - (index as f64 / n).powi(size)
        });
    }

    #[test]
    #[should_panic]
    fn test_tournament_rejects_zero_size() {
        let _ = TournamentSelection::new(0);
    }

    #[test]
    fn test_roulette_wheel_selection_probabilities() {
        let mut roulette_wheel = RouletteWheelSelection::new();
        let (population, observed_counts) = collect_samples(&mut roulette_wheel);

        let total_fitness = population.iter().map(|i| i.fitness).sum::<f64>();
        assert_probabilities(&observed_counts, |index| {
            population[index].fitness / total_fitness
        });
    }

    #[test]
    fn test_rank_selection_probabilities() {
        let mut rank = RankSelection::new();
        let (population, observed_counts) = collect_samples(&mut rank);

        let total_ranks = (population.len() * (population.len() + 1) / 2) as f64;
        assert_probabilities(&observed_counts, |index| (index + 1) as f64 / total_ranks);
    }

    #[test]
    fn test_elitist_selection_probabilities() {
        let mut elitist = ElitistSelection::new(0.7);
        let (population, observed_counts) = collect_samples(&mut elitist);

        let selection_bound = (population.len() as f64 * elitist.ratio).floor() as usize;
        let elite_probability = 1.0 / (population.len() - selection_bound) as f64;
        assert_probabilities(&observed_counts, |index| {
            if index >= selection_bound {
                elite_probability
            } else {
                0.0
            }
        });
    }

    #[test]
    #[should_panic]
    fn test_elitist_rejects_ratio_of_one() {
        let _ = ElitistSelection::new(1.0);
    }

    #[test]
    fn test_boltzmann_selection_probabilities() {
        // A cooling rate of 1.0 keeps the temperature at 5.0 through `prepare`.
        let mut boltzmann = BoltzmannSelection::new(5.0, 1.0);
        let (population, observed_counts) = collect_samples(&mut boltzmann);

        let temperature = boltzmann.temperature;
        let weight = |fitness: f64| (fitness / temperature).exp();
        let total_weight = population.iter().map(|i| weight(i.fitness)).sum::<f64>();
        assert_probabilities(&observed_counts, |index| {
            weight(population[index].fitness) / total_weight
        });
    }

    #[test]
    fn test_linear_selection_probabilities() {
        let mut linear = LinearSelection::new(1.25);
        let (population, observed_counts) = collect_samples(&mut linear);

        // The selector draws an offset `x` from the top of the population with CDF
        // F(x) = bias * x - (bias - 1) * x^2, and maps [r / n, (r + 1) / n) to index n - 1 - r.
        // Check the exact distribution rather than mere monotonicity: neighbouring
        // probabilities differ by only 0.005 here. A "non-decreasing within CONFIDENCE" check
        // would also accept a uniform or even a reversed distribution.
        let n = population.len() as f64;
        let bias = linear.bias;
        let cdf = |x: f64| bias * x - (bias - 1.0) * x * x;
        assert_probabilities(&observed_counts, |index| {
            let rank_from_top = (population.len() - 1 - index) as f64;
            cdf((rank_from_top + 1.0) / n) - cdf(rank_from_top / n)
        });
    }
}
620}