1use super::agent::Agent;
16use super::fitness::{Score, ScoreProvider};
17use std::collections::{BTreeMap, HashSet};
18use std::hash::Hash;
19use rand::{
20 distributions::{Distribution, Standard},
21 Rng,
22};
23
24#[derive(Clone)]
25pub struct Population <Gene> {
26 agents: BTreeMap<Score, Agent<Gene>>,
27 register: HashSet<u64>,
28 unique_agents: bool,
29
30}
31
32impl <Gene> Population <Gene> {
33
34 pub fn new_empty(unique: bool) -> Self {
35 Self {
36 agents: BTreeMap::new(),
37 register: HashSet::new(),
38 unique_agents: unique
39 }
40 }
41
42 pub fn new<Data, SP>(
43 start_size: usize,
44 number_of_genes: usize,
45 unique: bool,
46 data: &Data,
47 score_provider: &mut SP,
48 ) -> Population<Gene>
49 where
50 Standard: Distribution<Gene>,
51 Gene: Hash + Clone,
52 SP: ScoreProvider<Gene, Data>
53 {
54 let mut population = Population::new_empty(unique);
55 let mut rng = rand::thread_rng();
56 let mut agents = Vec::new();
57 for _ in 0..start_size {
58 let agent = Agent::with_genes(number_of_genes);
59 if population.will_accept(&agent) {
60 agents.push(agent);
61 }
62 }
63
64 let agents = score_provider.evaluate_scores(agents, &data).unwrap();
65
66 for agent in agents {
67 let mut score = score_provider.get_score(&agent, &data, &mut rng).unwrap();
68
69 loop {
70 if score == 0 {
71 break;
72 }
73 if population.contains_score(score) {
74 score -= 1;
75 } else {
76 break;
77 }
78 }
79
80 population.insert(score, agent);
81 }
82
83 population
84 }
85
86 pub fn set_agents(&mut self, agents: BTreeMap<Score, Agent<Gene>>) {
87 for (score, agent) in agents {
88 self.insert(score, agent);
89 }
90 }
91
92 pub fn insert(&mut self, score: Score, agent: Agent<Gene>) {
93 if self.unique_agents {
94 if self.register.contains(&agent.get_hash()) {
95 return;
96 }
97 self.register.insert(agent.get_hash());
98 }
99 self.agents.insert(score, agent);
100 }
101
102 pub fn remove(&mut self, score: Score) -> Option<Agent<Gene>> where Gene: Clone {
103 let agent = self.agents.remove(&score);
104 if self.unique_agents && agent.is_some() {
105 self.register.remove(&agent.clone().unwrap().get_hash());
106 }
107 agent
108 }
109
110 pub fn get(&self, score: Score) -> Option<&Agent<Gene>> {
111 self.agents.get(&score)
112 }
113
114 pub fn get_agents(&self) -> &BTreeMap<Score, Agent<Gene>> {
115 &self.agents
116 }
117
118 pub fn len(&self) -> usize {
119 self.agents.len()
120 }
121
122 pub fn cull_all_below(&mut self, score: Score) {
123 self.agents = self.agents.split_off(&score);
124 if self.unique_agents {
125 self.register.clear();
126 for (_, agent) in &self.agents {
127 self.register.insert(agent.get_hash());
128 }
129 }
130 }
131
132 pub fn cull_all_above(&mut self, score: Score) {
133 self.agents.split_off(&score);
134 if self.unique_agents {
135 self.register.clear();
136 for (_, agent) in &self.agents {
137 self.register.insert(agent.get_hash());
138 }
139 }
140 }
141
142 pub fn contains_score(&self, score: Score) -> bool {
143 self.agents.contains_key(&score)
144 }
145
146 pub fn will_accept(&self, agent: &Agent<Gene>) -> bool {
147 if self.unique_agents {
148 return !self.register.contains(&agent.get_hash());
149 }
150 true
151 }
152
153 pub fn get_scores(&self) -> Vec<Score> {
154 self.agents.keys().map(|k| *k).collect()
155 }
156
157 pub fn get_random_score(&self) -> Score {
158 let mut rng = rand::thread_rng();
159 self.get_scores()[rng.gen_range(0, self.len())]
160 }
161}
162
#[cfg(test)]
mod tests {
    use super::super::fitness::{GeneralScoreProvider, ScoreError};
    use super::*;

    /// Asserts that every size accessor of `population` reports `expected`.
    fn assert_size<Gene>(population: &Population<Gene>, expected: usize) {
        assert_eq!(expected, population.len());
        assert_eq!(expected, population.get_agents().len());
        assert_eq!(expected, population.get_scores().len());
    }

    /// Counts up from zero and returns the first score not used in `population`.
    fn first_free_score<Gene>(population: &Population<Gene>) -> Score {
        let mut candidate = 0;
        while population.contains_score(candidate) {
            candidate += 1;
        }
        candidate
    }

    /// Scores an agent by the value of its first gene.
    fn get_score_index(agent: &Agent<u8>, _data: &u8) -> Result<Score, ScoreError> {
        Ok(agent.get_genes()[0] as Score)
    }

    #[test]
    fn new_empty() {
        let population: Population<u8> = Population::new_empty(false);
        assert_size(&population, 0);
    }

    #[test]
    fn new_with_false_unique() {
        let mut population = Population::new(5, 6, false, &0, &mut GeneralScoreProvider::new(get_score_index, 25));
        assert_size(&population, 5);
        for agent in population.get_agents().values() {
            assert_eq!(6, agent.get_genes().len());
        }

        // Without uniqueness, a copy of an existing agent is accepted.
        let picked_score = population.get_random_score();
        let duplicate = population.get(picked_score).unwrap().clone();
        assert!(population.will_accept(&duplicate));

        let free_score = first_free_score(&population);
        population.insert(free_score, duplicate);
        assert_size(&population, 6);
    }

    #[test]
    fn new_with_true_unique() {
        let mut population = Population::new(5, 6, true, &0, &mut GeneralScoreProvider::new(get_score_index, 25));
        assert_size(&population, 5);
        for agent in population.get_agents().values() {
            assert_eq!(6, agent.get_genes().len());
        }

        // With uniqueness, a copy of an existing agent is refused.
        let picked_score = population.get_random_score();
        let duplicate = population.get(picked_score).unwrap().clone();
        assert!(!population.will_accept(&duplicate));

        let free_score = first_free_score(&population);
        population.insert(free_score, duplicate.clone());
        assert_size(&population, 5);

        // Once the original has been removed, the copy can be inserted.
        population.remove(picked_score);
        assert_size(&population, 4);

        population.insert(free_score, duplicate);
        assert_size(&population, 5);
    }

    #[test]
    fn cull_all_below() {
        let mut population = Population::new(5, 6, true, &0, &mut GeneralScoreProvider::new(get_score_index, 25));
        assert_size(&population, 5);

        let scores = population.get_scores();
        let lowest = scores[0];
        let second_lowest = scores[1];
        let middle = scores[2];
        let second_highest = scores[3];
        let highest = scores[4];

        assert!(highest > lowest);

        let lowest_clone = population.get(lowest).unwrap().clone();
        let highest_clone = population.get(highest).unwrap().clone();

        population.cull_all_below(middle);
        assert_size(&population, 3);

        assert!(!population.contains_score(lowest));
        assert!(!population.contains_score(second_lowest));
        assert!(population.contains_score(middle));
        assert!(population.contains_score(second_highest));
        assert!(population.contains_score(highest));

        let free_score = first_free_score(&population);

        // The highest agent survived the cull, so its copy is still refused.
        assert!(!population.will_accept(&highest_clone));
        population.insert(free_score, highest_clone);
        assert_size(&population, 3);

        // The lowest agent was culled, so its copy is accepted again.
        assert!(population.will_accept(&lowest_clone));
        population.insert(free_score, lowest_clone);
        assert_size(&population, 4);
    }
}