xu/
population.rs

1// Copyright 2019 Brendan Cox
2// 
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//    http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15use super::agent::Agent;
16use super::fitness::{Score, ScoreProvider};
17use std::collections::{BTreeMap, HashSet};
18use std::hash::Hash;
19use rand::{
20    distributions::{Distribution, Standard},
21    Rng,
22};
23
24#[derive(Clone)]
25pub struct Population <Gene> {
26    agents: BTreeMap<Score, Agent<Gene>>,
27    register: HashSet<u64>,
28    unique_agents: bool,
29
30}
31
32impl <Gene> Population <Gene> {
33
34    pub fn new_empty(unique: bool) -> Self {
35        Self {
36            agents: BTreeMap::new(),
37            register: HashSet::new(),
38            unique_agents: unique
39        }
40    }
41
42    pub fn new<Data, SP>(
43        start_size: usize,
44        number_of_genes: usize,
45        unique: bool,
46        data: &Data,
47        score_provider: &mut SP,
48    ) -> Population<Gene> 
49    where
50    Standard: Distribution<Gene>,
51    Gene: Hash + Clone,
52    SP: ScoreProvider<Gene, Data>
53    {
54        let mut population = Population::new_empty(unique);
55        let mut rng = rand::thread_rng();
56        let mut agents = Vec::new();
57        for _ in 0..start_size {
58            let agent = Agent::with_genes(number_of_genes);
59            if population.will_accept(&agent) {
60                agents.push(agent);
61            }
62        }
63
64        let agents = score_provider.evaluate_scores(agents, &data).unwrap();
65
66        for agent in agents {
67            let mut score = score_provider.get_score(&agent, &data, &mut rng).unwrap();
68
69            loop {
70                if score == 0 {
71                    break;
72                }
73                if population.contains_score(score) {
74                    score -= 1;
75                } else {
76                    break;
77                }
78            }
79
80            population.insert(score, agent);
81        }
82
83        population
84    }
85
86    pub fn set_agents(&mut self, agents: BTreeMap<Score, Agent<Gene>>) {
87        for (score, agent) in agents {
88            self.insert(score, agent);
89        }
90    }
91
92    pub fn insert(&mut self, score: Score, agent: Agent<Gene>) {
93        if self.unique_agents {
94            if self.register.contains(&agent.get_hash()) {
95                return;
96            }
97            self.register.insert(agent.get_hash());
98        }
99        self.agents.insert(score, agent);
100    }
101
102    pub fn remove(&mut self, score: Score) -> Option<Agent<Gene>> where Gene: Clone {
103        let agent = self.agents.remove(&score);
104        if self.unique_agents && agent.is_some() {
105            self.register.remove(&agent.clone().unwrap().get_hash());
106        }
107        agent
108    }
109
110    pub fn get(&self, score: Score) -> Option<&Agent<Gene>> {
111        self.agents.get(&score)
112    }
113
114    pub fn get_agents(&self) -> &BTreeMap<Score, Agent<Gene>> {
115        &self.agents
116    }
117
118    pub fn len(&self) -> usize {
119        self.agents.len()
120    }
121
122    pub fn cull_all_below(&mut self, score: Score) {
123        self.agents = self.agents.split_off(&score);
124        if self.unique_agents {
125            self.register.clear();
126            for (_, agent) in &self.agents {
127                self.register.insert(agent.get_hash());
128            }
129        }
130    }
131
132    pub fn cull_all_above(&mut self, score: Score) {
133        self.agents.split_off(&score);
134        if self.unique_agents {
135            self.register.clear();
136            for (_, agent) in &self.agents {
137                self.register.insert(agent.get_hash());
138            }
139        }
140    }
141
142    pub fn contains_score(&self, score: Score) -> bool {
143        self.agents.contains_key(&score)
144    }
145
146    pub fn will_accept(&self, agent: &Agent<Gene>) -> bool {
147        if self.unique_agents {
148            return !self.register.contains(&agent.get_hash());
149        }
150        true
151    }
152
153    pub fn get_scores(&self) -> Vec<Score> {
154        self.agents.keys().map(|k| *k).collect()
155    }
156
157    pub fn get_random_score(&self) -> Score {
158        let mut rng = rand::thread_rng();
159        self.get_scores()[rng.gen_range(0, self.len())]
160    }
161}
162
#[cfg(test)]
mod tests {
    use super::super::fitness::{GeneralScoreProvider, ScoreError};
    use super::*;

    /// Scores an agent with the value of its first gene.
    fn get_score_index(agent: &Agent<u8>, _data: &u8) -> Result<Score, ScoreError> {
        Ok(agent.get_genes()[0] as Score)
    }

    /// Builds the population shared by the tests: 5 agents of 6 genes each.
    fn build_population(unique: bool) -> Population<u8> {
        Population::new(5, 6, unique, &0, &mut GeneralScoreProvider::new(get_score_index, 25))
    }

    /// Asserts that every size accessor of `population` reports `expected`.
    fn assert_size(population: &Population<u8>, expected: usize) {
        assert_eq!(expected, population.len());
        assert_eq!(expected, population.get_agents().len());
        assert_eq!(expected, population.get_scores().len());
    }

    /// Returns the lowest score that no agent in `population` occupies.
    fn lowest_free_score(population: &Population<u8>) -> Score {
        let mut candidate = 0;
        while population.contains_score(candidate) {
            candidate += 1;
        }
        candidate
    }

    #[test]
    fn new_empty() {
        let population: Population<u8> = Population::new_empty(false);
        assert_size(&population, 0);
    }

    #[test]
    fn new_with_false_unique() {
        let mut population = build_population(false);
        assert_size(&population, 5);
        for agent in population.get_agents().values() {
            assert_eq!(6, agent.get_genes().len());
        }

        // Without uniqueness a copy of an existing agent is welcome.
        let picked_score = population.get_random_score();
        let duplicate = population.get(picked_score).unwrap().clone();
        assert!(population.will_accept(&duplicate));

        let free_score = lowest_free_score(&population);
        population.insert(free_score, duplicate);
        assert_size(&population, 6);
    }

    #[test]
    fn new_with_true_unique() {
        let mut population = build_population(true);
        assert_size(&population, 5);
        for agent in population.get_agents().values() {
            assert_eq!(6, agent.get_genes().len());
        }

        // With uniqueness a copy of an existing agent is refused.
        let picked_score = population.get_random_score();
        let duplicate = population.get(picked_score).unwrap().clone();
        assert!(!population.will_accept(&duplicate));

        let free_score = lowest_free_score(&population);
        population.insert(free_score, duplicate.clone());
        assert_size(&population, 5);

        // Once the original is gone the copy can take its place.
        population.remove(picked_score);
        assert_size(&population, 4);

        population.insert(free_score, duplicate);
        assert_size(&population, 5);
    }

    #[test]
    fn cull_all_below() {
        let mut population = build_population(true);
        assert_size(&population, 5);

        let scores = population.get_scores();
        let (lowest, second_lowest, middle, second_highest, highest) =
            (scores[0], scores[1], scores[2], scores[3], scores[4]);

        // Ensure ordering is as expected.
        assert!(highest > lowest);

        // Kept to verify that the register of hashes follows the cull.
        let lowest_clone = population.get(lowest).unwrap().clone();
        let highest_clone = population.get(highest).unwrap().clone();

        population.cull_all_below(middle);
        assert_size(&population, 3);

        for culled in &[lowest, second_lowest] {
            assert!(!population.contains_score(*culled));
        }
        for kept in &[middle, second_highest, highest] {
            assert!(population.contains_score(*kept));
        }

        let free_score = lowest_free_score(&population);

        // The highest agent survived the cull, so its clone is still refused.
        assert!(!population.will_accept(&highest_clone));
        population.insert(free_score, highest_clone);
        assert_size(&population, 3);

        // The lowest agent was culled, so its clone is accepted again.
        assert!(population.will_accept(&lowest_clone));
        population.insert(free_score, lowest_clone);
        assert_size(&population, 4);
    }
}