roulette_wheel/lib.rs
//! A small implementation of the roulette-wheel selection principle (also known as
//! fitness proportionate selection), `RouletteWheel<T>`.
//! <https://en.wikipedia.org/wiki/Fitness_proportionate_selection>
//!
//! # Example usage
7//!
8//! ```
9//! use roulette_wheel::RouletteWheel;
10//!
11//! fn evaluate(individual: &i32) -> f32 { *individual as f32 } // mmm...!
12//!
13//! let population: Vec<_> = (1..10).into_iter().collect();
14//! let fitnesses: Vec<_> = population.iter().map(|ind| evaluate(ind)).collect();
15//!
16//! let rw: RouletteWheel<_> = fitnesses.into_iter().zip(population).collect();
17//!
18//! // let's collect the individuals in the order in which the roulette wheel gives them
19//! let individuals: Vec<_> = rw.into_iter().map(|(_, ind)| ind).collect();
//! // note: `rw.select_iter()` only borrows the roulette wheel,
//! // whereas `rw.into_iter()` (used above) consumes it.
22//!
23//! fn crossover(mother: &i32, _father: &i32) -> i32 { mother.clone() } // unimplemented!()
24//!
25//! // now merge each individual by couples
26//! let new_population: Vec<_> = individuals.chunks(2)
27//! .filter(|couple| couple.len() == 2)
28//! .map(|couple| {
29//! let (mother, father) = (couple[0], couple[1]);
30//! crossover(&mother, &father)
//!         // note: this example returns only one child per couple, so the
//!         // population shrinks; returning several children per couple
//!         // with `.flat_map()` would keep its size.
34//! }).collect();
35//! ```
36
37extern crate rand;
38
39use std::iter::{FromIterator, Iterator, IntoIterator};
40use rand::{Rng, ThreadRng, thread_rng};
41use rand::distributions::{Range, IndependentSample};
42
/// A roulette-wheel container: a population of individuals, each paired with a
/// fitness, plus the cached sum of those fitnesses.
#[derive(Clone, Debug)]
pub struct RouletteWheel<T> {
    /// Cached total of `fitnesses`; intended to equal their sum.
    total_fitness: f32,
    /// Fitness of each individual: `fitnesses[i]` belongs to `population[i]`.
    fitnesses: Vec<f32>,
    /// The individuals, stored parallel to `fitnesses`.
    population: Vec<T>,
}

impl<T> Default for RouletteWheel<T> {
    /// Returns an empty wheel with a total fitness of `0.0`.
    fn default() -> Self {
        RouletteWheel {
            total_fitness: 0.0,
            fitnesses: Vec::new(),
            population: Vec::new(),
        }
    }
}

impl<T> FromIterator<(f32, T)> for RouletteWheel<T> {
    /// Builds a wheel from `(fitness, individual)` pairs, keeping their order
    /// and caching the sum of all fitnesses.
    ///
    /// NOTE(review): unlike `push`, the fitnesses are not validated here
    /// (negative, NaN or infinite values are accepted as-is).
    fn from_iter<A>(iter: A) -> Self
    where
        A: IntoIterator<Item = (f32, T)>,
    {
        let (fitnesses, population): (Vec<f32>, Vec<T>) = iter.into_iter().unzip();
        let total_fitness = fitnesses.iter().sum();
        RouletteWheel {
            total_fitness,
            fitnesses,
            population,
        }
    }
}
71
72impl<T> RouletteWheel<T> {
73 /// create a new empty random-wheel.
74 /// # Example
75 ///
76 /// ```
77 /// use roulette_wheel::RouletteWheel;
78 ///
79 /// let rw = RouletteWheel::<u8>::new();
80 /// ```
81 pub fn new() -> RouletteWheel<T> {
82 RouletteWheel {
83 total_fitness: 0.0,
84 fitnesses: Vec::new(),
85 population: Vec::new()
86 }
87 }
88
89 /// Creates an empty RouletteWheel with space for at least n elements.
90 /// # Example
91 ///
92 /// ```
93 /// use roulette_wheel::RouletteWheel;
94 ///
95 /// let rw = RouletteWheel::<u8>::with_capacity(15);
96 ///
97 /// assert_eq!(rw.len(), 0);
98 /// ```
99 pub fn with_capacity(cap: usize) -> RouletteWheel<T> {
100 RouletteWheel {
101 total_fitness: 0.0,
102 fitnesses: Vec::with_capacity(cap),
103 population: Vec::with_capacity(cap)
104 }
105 }
106
107 /// Reserves capacity for at least `additional` more elements to be inserted.
108 /// The collection may reserve more space to avoid frequent reallocations.
109 /// # Example
110 ///
111 /// ```
112 /// use roulette_wheel::RouletteWheel;
113 ///
114 /// let mut rw = RouletteWheel::<u8>::new();
115 /// rw.reserve(20);
116 ///
117 /// assert_eq!(rw.len(), 0);
118 /// ```
119 pub fn reserve(&mut self, additional: usize) {
120 self.fitnesses.reserve(additional);
121 self.population.reserve(additional);
122 }
123
124 /// Returns the number of elements in the wheel.
125 /// # Example
126 ///
127 /// ```
128 /// use roulette_wheel::RouletteWheel;
129 ///
130 /// let rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
131 ///
132 /// assert_eq!(rw.len(), 3);
133 /// ```
134 pub fn len(&self) -> usize {
135 self.population.len()
136 }
137
138 /// Returns `true` if empty else return `false`.
139 /// # Example
140 ///
141 /// ```
142 /// use roulette_wheel::RouletteWheel;
143 ///
144 /// let empty_rw = RouletteWheel::<u8>::new();
145 ///
146 /// assert_eq!(empty_rw.is_empty(), true);
147 ///
148 /// let non_empty_rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
149 ///
150 /// assert_eq!(non_empty_rw.is_empty(), false);
151 /// ```
152 pub fn is_empty(&self) -> bool {
153 self.population.is_empty()
154 }
155
156 /// Remove all elements in this wheel.
157 /// # Example
158 ///
159 /// ```
160 /// use roulette_wheel::RouletteWheel;
161 ///
162 /// let mut rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
163 ///
164 /// assert_eq!(rw.len(), 3);
165 ///
166 /// rw.clear();
167 ///
168 /// assert_eq!(rw.len(), 0);
169 /// ```
170 pub fn clear(&mut self) {
171 self.fitnesses.clear();
172 self.population.clear();
173 }
174
175 /// Add an element associated with a probability.
176 ///
177 /// # Panics
178 ///
179 /// This function might panic if the fitness is less than zero
180 /// or if the total fitness gives a non-finite fitness (`Inf`).
181 ///
182 /// # Example
183 ///
184 /// ```
185 /// use roulette_wheel::RouletteWheel;
186 ///
187 /// let mut rw = RouletteWheel::new();
188 ///
189 /// rw.push(1.0, 'r');
190 /// rw.push(1.0, 'c');
191 /// rw.push(1.0, 'a');
192 ///
193 /// assert_eq!(rw.len(), 3);
194 /// ```
195 pub fn push(&mut self, fitness: f32, individual: T) {
196 assert!(fitness >= 0.0, "Can't push the less than zero fitness: {:?}", fitness);
197 assert!((self.total_fitness + fitness).is_finite(), "Fitnesses sum reached a non-finite value!");
198 unsafe { self.unchecked_push(fitness, individual) }
199 }
200
201 /// Add an element associated with a probability.
202 /// This unsafe function doesn't check for total fitness overflow
203 /// nether fitness positivity.
204 /// # Example
205 ///
206 /// ```
207 /// use roulette_wheel::RouletteWheel;
208 ///
209 /// let mut rw = RouletteWheel::new();
210 ///
211 /// unsafe { rw.unchecked_push(1.0, 'r') };
212 /// unsafe { rw.unchecked_push(1.0, 'c') };
213 /// unsafe { rw.unchecked_push(1.0, 'a') };
214 ///
215 /// assert_eq!(rw.len(), 3);
216 /// ```
217 pub unsafe fn unchecked_push(&mut self, fitness: f32, individual: T) {
218 self.total_fitness += fitness;
219 self.fitnesses.push(fitness);
220 self.population.push(individual);
221 }
222
223 /// Returns the sum of all individual fitnesses.
224 /// # Example
225 ///
226 /// ```
227 /// use roulette_wheel::RouletteWheel;
228 ///
229 /// let mut rw = RouletteWheel::new();
230 ///
231 /// rw.push(3.0, 'r');
232 /// rw.push(2.0, 'c');
233 /// rw.push(1.5, 'a');
234 ///
235 /// assert_eq!(rw.total_fitness(), 6.5);
236 /// ```
237 pub fn total_fitness(&self) -> f32 {
238 self.total_fitness
239 }
240
241 /// Returns an iterator over the RouletteWheel.
242 ///
243 /// # Examples
244 ///
245 /// ``` ignore
246 /// use roulette_wheel::RouletteWheel;
247 ///
248 /// let rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
249 /// let mut iterator = rw.select_iter();
250 ///
251 /// assert_eq!(iterator.next(), Some((0.5, &20)));
252 /// assert_eq!(iterator.next(), Some((0.1, &10)));
253 /// assert_eq!(iterator.next(), Some((0.2, &15)));
254 /// assert_eq!(iterator.next(), None);
255 /// ```
256 pub fn select_iter(&self) -> SelectIter<ThreadRng, T> {
257 SelectIter::<ThreadRng, _>::new(&self)
258 }
259}
260
261/// Immutable RouletteWheel iterator
262///
263/// This struct is created by the [`select_iter`].
264///
265/// [`iter`]: struct.RouletteWheel.html#method.select_iter
266pub struct SelectIter<'a, R: Rng, T: 'a> {
267 distribution_range: Range<f32>,
268 rng: R,
269 total_fitness: f32,
270 fitnesses_ids: Vec<(usize, f32)>,
271 roulette_wheel: &'a RouletteWheel<T>
272}
273
274impl<'a, R: Rng, T> SelectIter<'a, R, T> {
275 pub fn new(roulette_wheel: &'a RouletteWheel<T>) -> SelectIter<'a, ThreadRng, T> {
276 SelectIter::from_rng(roulette_wheel, thread_rng())
277 }
278
279 pub fn from_rng(roulette_wheel: &'a RouletteWheel<T>, rng: R) -> SelectIter<'a, R, T> {
280 SelectIter {
281 distribution_range: Range::new(0.0, 1.0),
282 rng: rng,
283 total_fitness: roulette_wheel.total_fitness,
284 fitnesses_ids: roulette_wheel.fitnesses.iter().cloned().enumerate().collect(),
285 roulette_wheel: roulette_wheel
286 }
287 }
288}
289
290impl<'a, R: Rng, T: 'a> Iterator for SelectIter<'a, R, T> {
291 type Item = (f32, &'a T);
292
293 fn size_hint(&self) -> (usize, Option<usize>) {
294 (self.fitnesses_ids.len(), Some(self.fitnesses_ids.len()))
295 }
296
297 fn next(&mut self) -> Option<Self::Item> {
298 if !self.fitnesses_ids.is_empty() {
299 let sample = self.distribution_range.ind_sample(&mut self.rng);
300 let mut selection = sample * self.total_fitness;
301 let index = self.fitnesses_ids.iter().position(|&(_, fit)| {
302 selection -= fit;
303 selection <= 0.0
304 }).unwrap();
305 let (index, fitness) = self.fitnesses_ids.swap_remove(index);
306 self.total_fitness -= fitness;
307 Some((fitness, &self.roulette_wheel.population[index]))
308 }
309 else { None }
310 }
311}
312
313impl<T> IntoIterator for RouletteWheel<T> {
314 type Item = (f32, T);
315 type IntoIter = IntoSelectIter<ThreadRng, T>;
316
317 fn into_iter(self) -> IntoSelectIter<ThreadRng, T> {
318 IntoSelectIter::<ThreadRng, _>::new(self)
319 }
320}
321
322/// An iterator that moves out of a RouletteWheel.
323///
324/// This `struct` is created by the `into_iter` method on [`RouletteWheel`][`RouletteWheel`] (provided
325/// by the [`IntoIterator`] trait).
326///
327/// [`RouletteWheel`]: struct.RouletteWheel.html
328/// [`IntoIterator`]: https://doc.rust-lang.org/std/iter/trait.IntoIterator.html
329pub struct IntoSelectIter<R: Rng, T> {
330 distribution_range: Range<f32>,
331 rng: R,
332 total_fitness: f32,
333 fitnesses: Vec<f32>,
334 population: Vec<T>
335}
336
337impl<R: Rng, T> IntoSelectIter<R, T> {
338 pub fn new(roulette_wheel: RouletteWheel<T>) -> IntoSelectIter<ThreadRng, T> {
339 IntoSelectIter::from_rng(roulette_wheel, thread_rng())
340 }
341
342 pub fn from_rng(roulette_wheel: RouletteWheel<T>, rng: R) -> IntoSelectIter<R, T> {
343 IntoSelectIter {
344 distribution_range: Range::new(0.0, 1.0),
345 rng: rng,
346 total_fitness: roulette_wheel.total_fitness,
347 fitnesses: roulette_wheel.fitnesses,
348 population: roulette_wheel.population
349 }
350 }
351}
352
353impl<R: Rng, T> Iterator for IntoSelectIter<R, T> {
354 type Item = (f32, T);
355
356 fn size_hint(&self) -> (usize, Option<usize>) {
357 (self.fitnesses.len(), Some(self.fitnesses.len()))
358 }
359
360 fn next(&mut self) -> Option<Self::Item> {
361 if !self.fitnesses.is_empty() {
362 let sample = self.distribution_range.ind_sample(&mut self.rng);
363 let mut selection = sample * self.total_fitness;
364 let index = self.fitnesses.iter().position(|fit| {
365 selection -= *fit;
366 selection <= 0.0
367 }).unwrap();
368 let fitness = self.fitnesses.swap_remove(index);
369 let individual = self.population.swap_remove(index);
370 self.total_fitness -= fitness;
371 Some((fitness, individual))
372 }
373 else { None }
374 }
375}
376
#[cfg(test)]
mod tests {
    use rand::SeedableRng;
    use rand::StdRng;
    use {RouletteWheel, SelectIter, IntoSelectIter};

    const SEED: [usize; 4] = [4, 2, 42, 4242];

    /// Shared fixture for the seeded tests: individuals 15 to 19 with fitnesses
    /// 0.1 to 0.5, plus a freshly seeded generator.
    fn seeded_fixture() -> (RouletteWheel<i32>, StdRng) {
        let fitnesses = [0.1, 0.2, 0.3, 0.4, 0.5];
        let wheel = fitnesses.iter().cloned().zip(15..20).collect();
        (wheel, StdRng::from_seed(&SEED))
    }

    #[test]
    fn test_select_iter_seeded() {
        let (wheel, rng) = seeded_fixture();

        // Collecting drains the iterator until it returns `None`.
        let drawn: Vec<_> = SelectIter::from_rng(&wheel, rng).collect();

        let expected = vec![(0.5, &19), (0.3, &17), (0.4, &18), (0.2, &16), (0.1, &15)];
        assert_eq!(drawn, expected);
    }

    #[test]
    fn test_into_select_iter_seeded() {
        let (wheel, rng) = seeded_fixture();

        let drawn: Vec<_> = IntoSelectIter::from_rng(wheel, rng).collect();

        let expected = vec![(0.5, 19), (0.3, 17), (0.4, 18), (0.2, 16), (0.1, 15)];
        assert_eq!(drawn, expected);
    }

    #[test]
    fn test_len() {
        let mut wheel = RouletteWheel::<u8>::new();
        assert_eq!(wheel.len(), 0);

        for _ in 0..4 {
            wheel.push(0.1, 1);
        }

        assert_eq!(wheel.len(), 4);
    }
}