aristeia/
operations.rs

1// Copyright 2019 Brendan Cox
2// 
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::agent::{Agent, crossover};
16use super::population::Population;
17use std::hash::Hash;
18use rand::{
19    distributions::{Distribution, Standard},
20    Rng,
21};
22use std::marker::{Send, PhantomData};
23use std::collections::BTreeMap;
24use super::fitness::{Score, ScoreProvider};
25
26
/// The kind of change an operation applies to a population.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum OperationType {
    /// Insert mutated copies of the selected agents.
    Mutate,
    /// Insert children produced by crossing random pairs of the selected agents.
    Crossover,
    /// Remove agents from the population.
    Cull,
}
33
/// The strategy used to pick agents out of a population.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SelectionType {
    /// Pick agents at random, regardless of score.
    RandomAny,
    /// Pick the agents with the highest scores.
    HighestScore,
    /// Pick the agents with the lowest scores.
    LowestScore,
}
40
41/// Allows definition of parameters for selecting some agents from a population.
42#[derive(Clone, Copy)]
43pub struct Selection {
44    selection_type: SelectionType,
45    proportion: f64,
46    preferred_minimum: usize
47}
48
49impl Selection {
50    pub fn with_values(selection_type: SelectionType, proportion: f64, preferred_minimum: usize) -> Self {
51        Self {
52            selection_type: selection_type,
53            proportion: proportion,
54            preferred_minimum: preferred_minimum
55        }
56    }
57
58    pub fn new(selection_type: SelectionType, proportion: f64) -> Self {
59        Self {
60            selection_type: selection_type,
61            proportion: proportion,
62            preferred_minimum: 1
63        }
64    }
65
66    pub fn selection_type(&self) -> SelectionType {
67        self.selection_type
68    }
69
70    pub fn proportion(&self) -> f64 {
71        self.proportion
72    }
73
74    pub fn preferred_minimum(&self) -> usize {
75        self.preferred_minimum
76    }
77
78    pub fn agents <'a, Gene> (&self, population: &'a Population<Gene>) -> BTreeMap<Score, &'a Agent<Gene>>
79    where
80    Gene: Clone
81    {
82        match self.selection_type {
83            SelectionType::RandomAny => get_random_subset(population.get_agents(), self.proportion, self.preferred_minimum),
84            SelectionType::HighestScore => get_highest_scored_agents(population.get_agents(), self.proportion, self.preferred_minimum),
85            SelectionType::LowestScore => get_lowest_scored_agents(population.get_agents(), self.proportion, self.preferred_minimum)
86        }
87    }
88
89    pub fn count <Gene> (&self, population: &Population<Gene>) -> usize {
90        rate_to_number(population.len(), self.proportion, self.preferred_minimum)
91    }
92}
93
94/// Modifies a selection of a population.
95#[derive(Clone)]
96pub struct Operation <Gene, Data>
97where
98Standard: Distribution<Gene>,
99Gene: Clone + Hash + Send + 'static,
100Data: Clone + Send + 'static
101{
102    selection: Selection,
103    operation_type: OperationType,
104    gene: PhantomData<Gene>,
105    data: PhantomData<Data>
106}
107
108impl <Gene, Data> Operation <Gene, Data>
109where
110Standard: Distribution<Gene>,
111Gene: Clone + Hash + Send + 'static,
112Data: Clone + Send + 'static
113{
114    pub fn with_values(
115        selection: Selection,
116        operation_type: OperationType
117        ) -> Self {
118        Self {
119            selection: selection,
120            operation_type: operation_type,
121            gene: PhantomData,
122            data: PhantomData
123        }
124    }
125
126    pub fn new(
127        operation_type: OperationType,
128        selection: Selection
129    ) -> Self {
130        Self {
131            selection: selection,
132            operation_type: operation_type,
133            gene: PhantomData,
134            data: PhantomData
135        }
136    }
137
138    pub fn run (&self, population: Population<Gene>, data: &Data, score_provider: &mut ScoreProvider<Gene, Data>) -> Population<Gene>
139    {
140        match self.operation_type {
141            OperationType::Mutate => mutate_agents(population, self.selection, data, score_provider),
142            OperationType::Crossover => crossover_agents(population, self.selection, data, score_provider),
143            OperationType::Cull => cull_agents(population, self.selection)
144        }
145    }
146}
147
148fn mutate_agents<Gene, Data>(
149    mut population: Population<Gene>,
150    selection: Selection,
151    data: &Data,
152    score_provider: &mut ScoreProvider<Gene, Data>
153) -> Population<Gene>
154where
155Standard: Distribution<Gene>,
156Gene: Clone + Hash + Send + 'static,
157Data: Clone + Send + 'static
158{
159    let children = get_mutated_agents(selection.agents(&population));
160    let children = score_provider.evaluate_scores(children, data).unwrap();
161    let mut rng = rand::thread_rng();
162    for agent in children {
163        let score_index = score_provider.get_score(&agent, data, &mut rng).unwrap();
164        population.insert(score_index, agent);
165    }
166
167    population
168}
169
170fn crossover_agents<Gene, Data>(
171    mut population: Population<Gene>,
172    selection: Selection,
173    data: &Data,
174    score_provider: &mut ScoreProvider<Gene, Data>
175) -> Population<Gene>
176where
177Standard: Distribution<Gene>,
178Gene: Clone + Hash + Send + 'static,
179Data: Clone + Send + 'static
180{
181    let pairs = create_random_pairs(
182        selection.agents(&population)
183    );
184
185    let children = create_children_from_crossover(pairs, data, score_provider);
186    for (score_index, agent) in children {
187        population.insert(score_index, agent);
188    }
189
190    population
191}
192
193fn cull_agents<Gene>(
194    mut population: Population<Gene>,
195    selection: Selection,
196) -> Population<Gene>
197{
198    let keys: Vec<Score> = population.get_agents().keys().map(|k| *k).collect();
199    let cull_number = selection.count(&population);
200    if cull_number >= keys.len() {
201        return population;
202    }
203    
204    match selection.selection_type() {
205        SelectionType::LowestScore => population.cull_all_below(keys[cull_number]),
206        SelectionType::HighestScore => population.cull_all_above(keys[cull_number]),
207        SelectionType::RandomAny => panic!("RandomAny selection not yet implemented for cull agents")
208    };
209    population
210}
211
212fn get_mutated_agents<Gene>(
213    agents: BTreeMap<Score, &Agent<Gene>>,
214) -> Vec<Agent<Gene>>
215where Standard: Distribution<Gene>,
216Gene: Clone + Hash + Send
217{
218    let mut children = Vec::new();
219    for (_, mut agent) in agents {
220        let mut clone = agent.clone();
221        clone.mutate();
222        children.push(clone);
223    }
224    children
225}
226
227fn create_children_from_crossover<Gene, Data>(
228    pairs: Vec<(Agent<Gene>, Agent<Gene>)>,
229    data: &Data,
230    score_provider: &mut ScoreProvider<Gene, Data>,
231) -> Vec<(Score, Agent<Gene>)>
232where
233Standard: Distribution<Gene>,
234Gene: Clone + Hash
235{
236    let mut children = Vec::new();
237
238    for (parent_one, parent_two) in pairs {
239        let child = crossover(&parent_one, &parent_two);
240        children.push(child);
241    }
242    let children = score_provider.evaluate_scores(children, data).unwrap();
243
244    let mut agents = Vec::new();
245    let mut rng = rand::thread_rng();
246    for agent in children {
247        let score_index = score_provider.get_score(&agent, data, &mut rng).unwrap();
248        agents.push((score_index, agent));
249    }
250    return agents;
251}
252
253fn get_random_subset<Gene>(
254    agents: &BTreeMap<Score, Agent<Gene>>,
255    rate: f64,
256    preferred_minimum: usize
257) -> BTreeMap<Score, &Agent<Gene>>
258where Gene: Clone
259{
260    let number = rate_to_number(agents.len(), rate, preferred_minimum);
261    let keys: Vec<Score> = agents.keys().map(|k| *k).collect();
262    let mut rng = rand::thread_rng();
263    let mut subset = BTreeMap::new();
264    for _ in 0..number {
265        let key = keys[rng.gen_range(0, keys.len())];
266        let agent = agents.get(&key);
267        if agent.is_some() {
268            subset.insert(key, agent.unwrap());
269        }
270    }
271
272    subset
273}
274
275fn get_highest_scored_agents<Gene>(
276    agents: &BTreeMap<Score, Agent<Gene>>,
277    rate: f64,
278    preferred_minimum: usize
279) -> BTreeMap<Score, &Agent<Gene>>
280where Gene: Clone
281{
282    let number = rate_to_number(agents.len(), rate, preferred_minimum);
283    let mut keys: Vec<Score> = agents.keys().map(|k| *k).collect();
284    let keys_len = keys.len();
285    keys.drain(0..(keys_len - number));
286    let mut subset = BTreeMap::new();
287    for key in keys {
288        let agent = agents.get(&key);
289        if agent.is_some() {
290            subset.insert(key, agent.unwrap());
291        }
292    }
293
294    subset
295}
296
297fn get_lowest_scored_agents<Gene>(
298    agents: &BTreeMap<Score, Agent<Gene>>,
299    rate: f64,
300    preferred_minimum: usize
301) -> BTreeMap<Score, &Agent<Gene>>
302where Gene: Clone
303{
304    let number = rate_to_number(agents.len(), rate, preferred_minimum);
305    let mut keys: Vec<Score> = agents.keys().map(|k| *k).collect();
306    keys.truncate(number);
307    let mut subset = BTreeMap::new();
308    for key in keys {
309        let agent = agents.get(&key);
310        if agent.is_some() {
311            subset.insert(key, agent.unwrap());
312        }
313    }
314
315    subset
316}
317
318fn create_random_pairs<Gene>(
319    agents: BTreeMap<Score, &Agent<Gene>>,
320) -> Vec<(Agent<Gene>, Agent<Gene>)> 
321where
322Gene: Clone
323{
324    let keys: Vec<&Score> = agents.keys().collect();
325    let mut rng = rand::thread_rng();
326    let mut pairs = Vec::new();
327    let count = keys.len();
328    for _ in 0..count {
329        let one_key = keys[rng.gen_range(0, keys.len())];
330        let two_key = keys[rng.gen_range(0, keys.len())];
331
332        let one_agent = agents.get(one_key);
333        let two_agent = agents.get(two_key);
334        if one_agent.is_some() && two_agent.is_some() {
335            let one_agent = *one_agent.unwrap();
336            let two_agent = *two_agent.unwrap();
337            if !one_agent.has_same_genes(two_agent) {
338                pairs.push((one_agent.clone(), two_agent.clone()));
339            }
340        }
341    }
342
343    pairs
344}
345
346
347pub fn cull_lowest_agents<Gene>(
348    mut population: Population<Gene>,
349    rate: f64,
350    preferred_minimum: usize
351) -> Population<Gene>
352{
353    let keys: Vec<Score> = population.get_agents().keys().map(|k| *k).collect();
354    let cull_number = rate_to_number(keys.len(), rate, preferred_minimum);
355    if cull_number >= keys.len() {
356        return population;
357    }
358    population.cull_all_below(keys[cull_number]);
359    population
360}
361
/// Converts a proportion of a population into a whole number of agents.
///
/// * If the population is smaller than `preferred_minimum`, the whole
///   population is returned.
/// * Otherwise the result is `population * rate` truncated towards zero,
///   raised to `preferred_minimum` if it falls below it.
///
/// The result never exceeds `population`, even when `rate` is greater than
/// `1.0`, so callers can safely use it as a count of existing agents.
fn rate_to_number(population: usize, rate: f64, preferred_minimum: usize) -> usize {
    if population < preferred_minimum {
        return population;
    }

    let proportional = (population as f64 * rate) as usize;
    // `preferred_minimum <= population` holds here, so clamping to the
    // population afterwards cannot undo the minimum.
    proportional.max(preferred_minimum).min(population)
}
373
#[cfg(test)]
mod tests {
    use super::*;
    use super::super::fitness::{GeneralScoreProvider, ScoreError};

    /// Scores an agent by its first gene.
    fn get_score_index(agent: &Agent<u8>, _data: &u8) -> Result<Score, ScoreError> {
        Ok(agent.get_genes()[0] as Score)
    }

    /// Builds the population shared by the selection tests.
    // NOTE(review): presumably eight agents with one gene each — confirm
    // against `Population::new`.
    fn sample_population() -> Population<u8> {
        let mut provider = GeneralScoreProvider::new(get_score_index, 25);
        Population::new(8, 1, false, &0, &mut provider)
    }

    #[test]
    fn selection_random_any_returns_correct_proportion() {
        let population = sample_population();
        let selection = Selection::with_values(SelectionType::RandomAny, 0.25, 0);

        let picked = selection.agents(&population);
        assert_eq!(2, picked.len());
    }

    #[test]
    fn selection_highest_score_returns_highest() {
        let population = sample_population();
        let selection = Selection::with_values(SelectionType::HighestScore, 0.25, 0);

        let picked = selection.agents(&population);
        assert_eq!(2, picked.len());

        // The two best scores in the population must both have been picked.
        for (score, _) in population.get_agents().iter().rev().take(2) {
            assert!(picked.contains_key(score));
        }
    }

    #[test]
    fn selection_lowest_score_returns_lowest() {
        let population = sample_population();
        let selection = Selection::with_values(SelectionType::LowestScore, 0.25, 0);

        let picked = selection.agents(&population);
        assert_eq!(2, picked.len());

        // The two worst scores in the population must both have been picked.
        for (score, _) in population.get_agents().iter().take(2) {
            assert!(picked.contains_key(score));
        }
    }

    #[test]
    fn rate_to_number_standard_proportion() {
        assert_eq!(16, rate_to_number(20, 0.8, 0));
    }

    #[test]
    fn rate_to_number_population_is_zero() {
        for rate in [0.0, 0.8].iter() {
            assert_eq!(0, rate_to_number(0, *rate, 0));
        }
    }

    #[test]
    fn rate_to_number_full_proportion() {
        assert_eq!(20, rate_to_number(20, 1.0, 0));
    }

    #[test]
    fn rate_to_number_rounds_down() {
        for rate in [0.75, 0.71, 0.79].iter() {
            assert_eq!(7, rate_to_number(10, *rate, 0));
        }
    }

    #[test]
    fn rate_to_number_minimum_preference_less_than_proportion() {
        assert_eq!(7, rate_to_number(10, 0.7, 5));
    }

    #[test]
    fn rate_to_number_minimum_preference_greater_than_proportion() {
        assert_eq!(8, rate_to_number(10, 0.7, 8));
    }

    #[test]
    fn rate_to_number_minimum_preference_greater_than_population() {
        assert_eq!(4, rate_to_number(4, 0.5, 5));
    }
}