1use super::agent::{Agent, crossover};
16use super::population::Population;
17use std::hash::Hash;
18use rand::{
19 distributions::{Distribution, Standard},
20 Rng,
21};
22use std::marker::{Send, PhantomData};
23use std::collections::BTreeMap;
24use super::fitness::{Score, ScoreProvider};
25
26
/// The kind of evolutionary step an `Operation` performs on a population.
///
/// `Debug`, `PartialEq` and `Eq` are derived so values can be logged and
/// compared directly (e.g. in tests or configuration checks).
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OperationType {
    /// Clone the selected agents and mutate the clones.
    Mutate,
    /// Pair up selected agents and breed children from each pair.
    Crossover,
    /// Remove agents from the population.
    Cull,
}
33
/// Strategy used to choose which agents an operation acts on.
///
/// `Debug`, `PartialEq` and `Eq` are derived so values can be logged and
/// compared directly.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SelectionType {
    /// Pick agents at random, regardless of score.
    RandomAny,
    /// Pick the agents with the largest scores.
    HighestScore,
    /// Pick the agents with the smallest scores.
    LowestScore,
}
40
41#[derive(Clone, Copy)]
43pub struct Selection {
44 selection_type: SelectionType,
45 proportion: f64,
46 preferred_minimum: usize
47}
48
49impl Selection {
50 pub fn with_values(selection_type: SelectionType, proportion: f64, preferred_minimum: usize) -> Self {
51 Self {
52 selection_type: selection_type,
53 proportion: proportion,
54 preferred_minimum: preferred_minimum
55 }
56 }
57
58 pub fn new(selection_type: SelectionType, proportion: f64) -> Self {
59 Self {
60 selection_type: selection_type,
61 proportion: proportion,
62 preferred_minimum: 1
63 }
64 }
65
66 pub fn selection_type(&self) -> SelectionType {
67 self.selection_type
68 }
69
70 pub fn proportion(&self) -> f64 {
71 self.proportion
72 }
73
74 pub fn preferred_minimum(&self) -> usize {
75 self.preferred_minimum
76 }
77
78 pub fn agents <'a, Gene> (&self, population: &'a Population<Gene>) -> BTreeMap<Score, &'a Agent<Gene>>
79 where
80 Gene: Clone
81 {
82 match self.selection_type {
83 SelectionType::RandomAny => get_random_subset(population.get_agents(), self.proportion, self.preferred_minimum),
84 SelectionType::HighestScore => get_highest_scored_agents(population.get_agents(), self.proportion, self.preferred_minimum),
85 SelectionType::LowestScore => get_lowest_scored_agents(population.get_agents(), self.proportion, self.preferred_minimum)
86 }
87 }
88
89 pub fn count <Gene> (&self, population: &Population<Gene>) -> usize {
90 rate_to_number(population.len(), self.proportion, self.preferred_minimum)
91 }
92}
93
94#[derive(Clone)]
96pub struct Operation <Gene, Data>
97where
98Standard: Distribution<Gene>,
99Gene: Clone + Hash + Send + 'static,
100Data: Clone + Send + 'static
101{
102 selection: Selection,
103 operation_type: OperationType,
104 gene: PhantomData<Gene>,
105 data: PhantomData<Data>
106}
107
108impl <Gene, Data> Operation <Gene, Data>
109where
110Standard: Distribution<Gene>,
111Gene: Clone + Hash + Send + 'static,
112Data: Clone + Send + 'static
113{
114 pub fn with_values(
115 selection: Selection,
116 operation_type: OperationType
117 ) -> Self {
118 Self {
119 selection: selection,
120 operation_type: operation_type,
121 gene: PhantomData,
122 data: PhantomData
123 }
124 }
125
126 pub fn new(
127 operation_type: OperationType,
128 selection: Selection
129 ) -> Self {
130 Self {
131 selection: selection,
132 operation_type: operation_type,
133 gene: PhantomData,
134 data: PhantomData
135 }
136 }
137
138 pub fn run (&self, population: Population<Gene>, data: &Data, score_provider: &mut ScoreProvider<Gene, Data>) -> Population<Gene>
139 {
140 match self.operation_type {
141 OperationType::Mutate => mutate_agents(population, self.selection, data, score_provider),
142 OperationType::Crossover => crossover_agents(population, self.selection, data, score_provider),
143 OperationType::Cull => cull_agents(population, self.selection)
144 }
145 }
146}
147
148fn mutate_agents<Gene, Data>(
149 mut population: Population<Gene>,
150 selection: Selection,
151 data: &Data,
152 score_provider: &mut ScoreProvider<Gene, Data>
153) -> Population<Gene>
154where
155Standard: Distribution<Gene>,
156Gene: Clone + Hash + Send + 'static,
157Data: Clone + Send + 'static
158{
159 let children = get_mutated_agents(selection.agents(&population));
160 let children = score_provider.evaluate_scores(children, data).unwrap();
161 let mut rng = rand::thread_rng();
162 for agent in children {
163 let score_index = score_provider.get_score(&agent, data, &mut rng).unwrap();
164 population.insert(score_index, agent);
165 }
166
167 population
168}
169
170fn crossover_agents<Gene, Data>(
171 mut population: Population<Gene>,
172 selection: Selection,
173 data: &Data,
174 score_provider: &mut ScoreProvider<Gene, Data>
175) -> Population<Gene>
176where
177Standard: Distribution<Gene>,
178Gene: Clone + Hash + Send + 'static,
179Data: Clone + Send + 'static
180{
181 let pairs = create_random_pairs(
182 selection.agents(&population)
183 );
184
185 let children = create_children_from_crossover(pairs, data, score_provider);
186 for (score_index, agent) in children {
187 population.insert(score_index, agent);
188 }
189
190 population
191}
192
193fn cull_agents<Gene>(
194 mut population: Population<Gene>,
195 selection: Selection,
196) -> Population<Gene>
197{
198 let keys: Vec<Score> = population.get_agents().keys().map(|k| *k).collect();
199 let cull_number = selection.count(&population);
200 if cull_number >= keys.len() {
201 return population;
202 }
203
204 match selection.selection_type() {
205 SelectionType::LowestScore => population.cull_all_below(keys[cull_number]),
206 SelectionType::HighestScore => population.cull_all_above(keys[cull_number]),
207 SelectionType::RandomAny => panic!("RandomAny selection not yet implemented for cull agents")
208 };
209 population
210}
211
212fn get_mutated_agents<Gene>(
213 agents: BTreeMap<Score, &Agent<Gene>>,
214) -> Vec<Agent<Gene>>
215where Standard: Distribution<Gene>,
216Gene: Clone + Hash + Send
217{
218 let mut children = Vec::new();
219 for (_, mut agent) in agents {
220 let mut clone = agent.clone();
221 clone.mutate();
222 children.push(clone);
223 }
224 children
225}
226
227fn create_children_from_crossover<Gene, Data>(
228 pairs: Vec<(Agent<Gene>, Agent<Gene>)>,
229 data: &Data,
230 score_provider: &mut ScoreProvider<Gene, Data>,
231) -> Vec<(Score, Agent<Gene>)>
232where
233Standard: Distribution<Gene>,
234Gene: Clone + Hash
235{
236 let mut children = Vec::new();
237
238 for (parent_one, parent_two) in pairs {
239 let child = crossover(&parent_one, &parent_two);
240 children.push(child);
241 }
242 let children = score_provider.evaluate_scores(children, data).unwrap();
243
244 let mut agents = Vec::new();
245 let mut rng = rand::thread_rng();
246 for agent in children {
247 let score_index = score_provider.get_score(&agent, data, &mut rng).unwrap();
248 agents.push((score_index, agent));
249 }
250 return agents;
251}
252
253fn get_random_subset<Gene>(
254 agents: &BTreeMap<Score, Agent<Gene>>,
255 rate: f64,
256 preferred_minimum: usize
257) -> BTreeMap<Score, &Agent<Gene>>
258where Gene: Clone
259{
260 let number = rate_to_number(agents.len(), rate, preferred_minimum);
261 let keys: Vec<Score> = agents.keys().map(|k| *k).collect();
262 let mut rng = rand::thread_rng();
263 let mut subset = BTreeMap::new();
264 for _ in 0..number {
265 let key = keys[rng.gen_range(0, keys.len())];
266 let agent = agents.get(&key);
267 if agent.is_some() {
268 subset.insert(key, agent.unwrap());
269 }
270 }
271
272 subset
273}
274
275fn get_highest_scored_agents<Gene>(
276 agents: &BTreeMap<Score, Agent<Gene>>,
277 rate: f64,
278 preferred_minimum: usize
279) -> BTreeMap<Score, &Agent<Gene>>
280where Gene: Clone
281{
282 let number = rate_to_number(agents.len(), rate, preferred_minimum);
283 let mut keys: Vec<Score> = agents.keys().map(|k| *k).collect();
284 let keys_len = keys.len();
285 keys.drain(0..(keys_len - number));
286 let mut subset = BTreeMap::new();
287 for key in keys {
288 let agent = agents.get(&key);
289 if agent.is_some() {
290 subset.insert(key, agent.unwrap());
291 }
292 }
293
294 subset
295}
296
297fn get_lowest_scored_agents<Gene>(
298 agents: &BTreeMap<Score, Agent<Gene>>,
299 rate: f64,
300 preferred_minimum: usize
301) -> BTreeMap<Score, &Agent<Gene>>
302where Gene: Clone
303{
304 let number = rate_to_number(agents.len(), rate, preferred_minimum);
305 let mut keys: Vec<Score> = agents.keys().map(|k| *k).collect();
306 keys.truncate(number);
307 let mut subset = BTreeMap::new();
308 for key in keys {
309 let agent = agents.get(&key);
310 if agent.is_some() {
311 subset.insert(key, agent.unwrap());
312 }
313 }
314
315 subset
316}
317
318fn create_random_pairs<Gene>(
319 agents: BTreeMap<Score, &Agent<Gene>>,
320) -> Vec<(Agent<Gene>, Agent<Gene>)>
321where
322Gene: Clone
323{
324 let keys: Vec<&Score> = agents.keys().collect();
325 let mut rng = rand::thread_rng();
326 let mut pairs = Vec::new();
327 let count = keys.len();
328 for _ in 0..count {
329 let one_key = keys[rng.gen_range(0, keys.len())];
330 let two_key = keys[rng.gen_range(0, keys.len())];
331
332 let one_agent = agents.get(one_key);
333 let two_agent = agents.get(two_key);
334 if one_agent.is_some() && two_agent.is_some() {
335 let one_agent = *one_agent.unwrap();
336 let two_agent = *two_agent.unwrap();
337 if !one_agent.has_same_genes(two_agent) {
338 pairs.push((one_agent.clone(), two_agent.clone()));
339 }
340 }
341 }
342
343 pairs
344}
345
346
347pub fn cull_lowest_agents<Gene>(
348 mut population: Population<Gene>,
349 rate: f64,
350 preferred_minimum: usize
351) -> Population<Gene>
352{
353 let keys: Vec<Score> = population.get_agents().keys().map(|k| *k).collect();
354 let cull_number = rate_to_number(keys.len(), rate, preferred_minimum);
355 if cull_number >= keys.len() {
356 return population;
357 }
358 population.cull_all_below(keys[cull_number]);
359 population
360}
361
/// Converts a proportion of `population` into a whole number of agents.
///
/// The product `population * rate` is truncated toward zero. It is then raised
/// to `preferred_minimum` if it falls short, and finally capped at `population`.
/// As a result, neither a `rate` above 1.0 nor a minimum larger than the
/// population can yield more agents than exist. A negative or NaN `rate`
/// truncates to zero before the minimum is applied.
fn rate_to_number(population: usize, rate: f64, preferred_minimum: usize) -> usize {
    if population < preferred_minimum {
        return population;
    }
    let number = (population as f64 * rate) as usize;
    // `preferred_minimum <= population` here, so max-then-min is well ordered.
    number.max(preferred_minimum).min(population)
}
373
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::fitness::{GeneralScoreProvider, ScoreError};

    /// Scores an agent by the value of its first gene.
    fn get_score_index(agent: &Agent<u8>, _data: &u8) -> Result<Score, ScoreError> {
        Ok(agent.get_genes()[0] as Score)
    }

    /// Builds the eight-agent, single-gene population shared by the selection tests.
    fn sample_population() -> Population<u8> {
        Population::new(8, 1, false, &0, &mut GeneralScoreProvider::new(get_score_index, 25))
    }

    #[test]
    fn selection_random_any_returns_correct_proportion() {
        let population = sample_population();
        let selected = Selection::with_values(SelectionType::RandomAny, 0.25, 0).agents(&population);
        assert_eq!(2, selected.len());
    }

    #[test]
    fn selection_highest_score_returns_highest() {
        let population = sample_population();
        let selected = Selection::with_values(SelectionType::HighestScore, 0.25, 0).agents(&population);
        assert_eq!(2, selected.len());

        // The two largest scores of the population must both be selected.
        for (score, _) in population.get_agents().iter().rev().take(2) {
            assert!(selected.contains_key(score));
        }
    }

    #[test]
    fn selection_lowest_score_returns_lowest() {
        let population = sample_population();
        let selected = Selection::with_values(SelectionType::LowestScore, 0.25, 0).agents(&population);
        assert_eq!(2, selected.len());

        // The two smallest scores of the population must both be selected.
        for (score, _) in population.get_agents().iter().take(2) {
            assert!(selected.contains_key(score));
        }
    }

    #[test]
    fn rate_to_number_standard_proportion() {
        assert_eq!(rate_to_number(20, 0.8, 0), 16);
    }

    #[test]
    fn rate_to_number_population_is_zero() {
        for &rate in &[0.0, 0.8] {
            assert_eq!(rate_to_number(0, rate, 0), 0);
        }
    }

    #[test]
    fn rate_to_number_full_proportion() {
        assert_eq!(rate_to_number(20, 1.0, 0), 20);
    }

    #[test]
    fn rate_to_number_rounds_down() {
        for &rate in &[0.75, 0.71, 0.79] {
            assert_eq!(rate_to_number(10, rate, 0), 7);
        }
    }

    #[test]
    fn rate_to_number_minimum_preference_less_than_proportion() {
        assert_eq!(rate_to_number(10, 0.7, 5), 7);
    }

    #[test]
    fn rate_to_number_minimum_preference_greater_than_proportion() {
        assert_eq!(rate_to_number(10, 0.7, 8), 8);
    }

    #[test]
    fn rate_to_number_minimum_preference_greater_than_population() {
        assert_eq!(rate_to_number(4, 0.5, 5), 4);
    }
}
463}