Skip to main content

radiate_selectors/
stochastic_sampling.rs

1use radiate_core::{Chromosome, Objective, Optimize, Population, Select, pareto, random_provider};
2
/// Selector that picks individuals with stochastic universal sampling.
///
/// The [`Select`] implementation in this file lays `count` evenly spaced
/// pointers over the cumulative, normalized fitness of the population,
/// starting from a single random offset, and picks the individual under each
/// pointer.
///
/// The type carries no state, so `new()` and `default()` are equivalent.
#[derive(Debug, Clone, Default)]
pub struct StochasticUniversalSamplingSelector;
5
6impl StochasticUniversalSamplingSelector {
7    pub fn new() -> Self {
8        StochasticUniversalSamplingSelector
9    }
10}
11
12impl<C: Chromosome + Clone> Select<C> for StochasticUniversalSamplingSelector {
13    fn select(
14        &self,
15        population: &Population<C>,
16        objective: &Objective,
17        count: usize,
18    ) -> Population<C> {
19        let fitness_values = match objective {
20            Objective::Single(opt) => {
21                let scores = population
22                    .get_scores()
23                    .map(|score| score.as_f32())
24                    .collect::<Vec<f32>>();
25                let total = scores.iter().sum::<f32>();
26                let mut fitness_values =
27                    scores.iter().map(|&fit| fit / total).collect::<Vec<f32>>();
28
29                if let Optimize::Minimize = opt {
30                    fitness_values.reverse();
31                }
32
33                fitness_values
34            }
35            Objective::Multi(_) => {
36                let weights =
37                    pareto::weights(&population.get_scores().collect::<Vec<_>>(), objective);
38                let total_weights = weights.iter().sum::<f32>();
39                weights
40                    .iter()
41                    .map(|&fit| fit / total_weights)
42                    .collect::<Vec<f32>>()
43            }
44        };
45
46        let fitness_total = fitness_values.iter().sum::<f32>();
47        let point_distance = fitness_total / count as f32;
48        let start_point = random_provider::range(0.0..point_distance);
49
50        let mut pointers = Vec::with_capacity(count);
51        let mut current_point = start_point;
52
53        for _ in 0..count {
54            let mut index = 0;
55            let mut fitness_sum = fitness_values[index];
56            while fitness_sum < current_point && index < fitness_values.len() - 1 {
57                index += 1;
58                fitness_sum += fitness_values[index];
59            }
60            pointers.push(population[index].clone());
61            current_point += point_distance;
62        }
63
64        Population::new(pointers)
65    }
66}