roulette_wheel/
lib.rs

//! A little implementation of the roulette-wheel principle, `RouletteWheel<T>`.
//! <https://wikipedia.org/wiki/Fitness_proportionate_selection>
3//!
4//! ![Fitness proportionate selection](https://upload.wikimedia.org/wikipedia/commons/2/2a/Fitness_proportionate_selection_example.png)
5//!
//! # Example usage
7//!
8//! ```
9//! use roulette_wheel::RouletteWheel;
10//!
11//! fn evaluate(individual: &i32) -> f32 { *individual as f32 } // mmm...!
12//!
13//! let population: Vec<_> = (1..10).into_iter().collect();
14//! let fitnesses: Vec<_> = population.iter().map(|ind| evaluate(ind)).collect();
15//!
16//! let rw: RouletteWheel<_> = fitnesses.into_iter().zip(population).collect();
17//!
18//! // let's collect the individuals in the order in which the roulette wheel gives them
19//! let individuals: Vec<_> = rw.into_iter().map(|(_, ind)| ind).collect();
//! // `rw.select_iter()` does not consume the roulette wheel,
//! // whereas `rw.into_iter()` does.
22//!
23//! fn crossover(mother: &i32, _father: &i32) -> i32 { mother.clone() } // unimplemented!()
24//!
25//! // now merge each individual by couples
26//! let new_population: Vec<_> = individuals.chunks(2)
27//!                                  .filter(|couple| couple.len() == 2)
28//!                                  .map(|couple| {
29//!                                       let (mother, father) = (couple[0], couple[1]);
30//!                                       crossover(&mother, &father)
31//!                                       // note: for this example we return only one individual,
32//!                                       //       the population will shrink
33//!                                       //       .flat_map() can resolve this issue
34//!                                   }).collect();
35//! ```
36
37extern crate rand;
38
39use std::iter::{FromIterator, Iterator, IntoIterator};
40use rand::{Rng, ThreadRng, thread_rng};
41use rand::distributions::{Range, IndependentSample};
42
/// A roulette-wheel container: individuals paired with fitnesses.
///
/// `fitnesses[i]` is the fitness of `population[i]`; the two vectors are
/// always pushed/cleared together so they stay index-aligned.
/// `total_fitness` is intended to equal the sum of `fitnesses`, so selection
/// does not have to recompute it.
#[derive(Debug)]
pub struct RouletteWheel<T> {
    /// Running sum of `fitnesses`.
    total_fitness: f32,
    /// Fitness of each individual, index-aligned with `population`.
    fitnesses: Vec<f32>,
    /// The individuals themselves.
    population: Vec<T>
}
49
50impl<T: Clone> Clone for RouletteWheel<T> {
51    fn clone(&self) -> RouletteWheel<T> {
52        RouletteWheel {
53            total_fitness: self.total_fitness,
54            fitnesses: self.fitnesses.clone(),
55            population: self.population.clone()
56        }
57    }
58}
59
60impl<T> FromIterator<(f32, T)> for RouletteWheel<T> {
61    fn from_iter<A>(iter: A) -> Self where A: IntoIterator<Item=(f32, T)> {
62        let (fitnesses, population): (Vec<f32>, _) = iter.into_iter().unzip();
63        let total_fitness = fitnesses.iter().sum();
64        RouletteWheel {
65            total_fitness: total_fitness,
66            fitnesses: fitnesses,
67            population: population
68        }
69    }
70}
71
72impl<T> RouletteWheel<T> {
73    /// create a new empty random-wheel.
74    /// # Example
75    ///
76    /// ```
77    /// use roulette_wheel::RouletteWheel;
78    ///
79    /// let rw = RouletteWheel::<u8>::new();
80    /// ```
81    pub fn new() -> RouletteWheel<T> {
82        RouletteWheel {
83            total_fitness: 0.0,
84            fitnesses: Vec::new(),
85            population: Vec::new()
86        }
87    }
88
89    /// Creates an empty RouletteWheel with space for at least n elements.
90    /// # Example
91    ///
92    /// ```
93    /// use roulette_wheel::RouletteWheel;
94    ///
95    /// let rw = RouletteWheel::<u8>::with_capacity(15);
96    ///
97    /// assert_eq!(rw.len(), 0);
98    /// ```
99    pub fn with_capacity(cap: usize) -> RouletteWheel<T> {
100        RouletteWheel {
101            total_fitness: 0.0,
102            fitnesses: Vec::with_capacity(cap),
103            population: Vec::with_capacity(cap)
104        }
105    }
106
107    /// Reserves capacity for at least `additional` more elements to be inserted.
108    /// The collection may reserve more space to avoid frequent reallocations.
109    /// # Example
110    ///
111    /// ```
112    /// use roulette_wheel::RouletteWheel;
113    ///
114    /// let mut rw = RouletteWheel::<u8>::new();
115    /// rw.reserve(20);
116    ///
117    /// assert_eq!(rw.len(), 0);
118    /// ```
119    pub fn reserve(&mut self, additional: usize) {
120        self.fitnesses.reserve(additional);
121        self.population.reserve(additional);
122    }
123
124    /// Returns the number of elements in the wheel.
125    /// # Example
126    ///
127    /// ```
128    /// use roulette_wheel::RouletteWheel;
129    ///
130    /// let rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
131    ///
132    /// assert_eq!(rw.len(), 3);
133    /// ```
134    pub fn len(&self) -> usize {
135        self.population.len()
136    }
137
138    /// Returns `true` if empty else return `false`.
139    /// # Example
140    ///
141    /// ```
142    /// use roulette_wheel::RouletteWheel;
143    ///
144    /// let empty_rw = RouletteWheel::<u8>::new();
145    ///
146    /// assert_eq!(empty_rw.is_empty(), true);
147    ///
148    /// let non_empty_rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
149    ///
150    /// assert_eq!(non_empty_rw.is_empty(), false);
151    /// ```
152    pub fn is_empty(&self) -> bool {
153        self.population.is_empty()
154    }
155
156    /// Remove all elements in this wheel.
157    /// # Example
158    ///
159    /// ```
160    /// use roulette_wheel::RouletteWheel;
161    ///
162    /// let mut rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
163    ///
164    /// assert_eq!(rw.len(), 3);
165    ///
166    /// rw.clear();
167    ///
168    /// assert_eq!(rw.len(), 0);
169    /// ```
170    pub fn clear(&mut self) {
171        self.fitnesses.clear();
172        self.population.clear();
173    }
174
175    /// Add an element associated with a probability.
176    ///
177    /// # Panics
178    ///
179    /// This function might panic if the fitness is less than zero
180    /// or if the total fitness gives a non-finite fitness (`Inf`).
181    ///
182    /// # Example
183    ///
184    /// ```
185    /// use roulette_wheel::RouletteWheel;
186    ///
187    /// let mut rw = RouletteWheel::new();
188    ///
189    /// rw.push(1.0, 'r');
190    /// rw.push(1.0, 'c');
191    /// rw.push(1.0, 'a');
192    ///
193    /// assert_eq!(rw.len(), 3);
194    /// ```
195    pub fn push(&mut self, fitness: f32, individual: T) {
196        assert!(fitness >= 0.0, "Can't push the less than zero fitness: {:?}", fitness);
197        assert!((self.total_fitness + fitness).is_finite(), "Fitnesses sum reached a non-finite value!");
198        unsafe { self.unchecked_push(fitness, individual) }
199    }
200
201    /// Add an element associated with a probability.
202    /// This unsafe function doesn't check for total fitness overflow
203    /// nether fitness positivity.
204    /// # Example
205    ///
206    /// ```
207    /// use roulette_wheel::RouletteWheel;
208    ///
209    /// let mut rw = RouletteWheel::new();
210    ///
211    /// unsafe { rw.unchecked_push(1.0, 'r') };
212    /// unsafe { rw.unchecked_push(1.0, 'c') };
213    /// unsafe { rw.unchecked_push(1.0, 'a') };
214    ///
215    /// assert_eq!(rw.len(), 3);
216    /// ```
217    pub unsafe fn unchecked_push(&mut self, fitness: f32, individual: T) {
218        self.total_fitness += fitness;
219        self.fitnesses.push(fitness);
220        self.population.push(individual);
221    }
222
223    /// Returns the sum of all individual fitnesses.
224    /// # Example
225    ///
226    /// ```
227    /// use roulette_wheel::RouletteWheel;
228    ///
229    /// let mut rw = RouletteWheel::new();
230    ///
231    /// rw.push(3.0, 'r');
232    /// rw.push(2.0, 'c');
233    /// rw.push(1.5, 'a');
234    ///
235    /// assert_eq!(rw.total_fitness(), 6.5);
236    /// ```
237    pub fn total_fitness(&self) -> f32 {
238        self.total_fitness
239    }
240
241    /// Returns an iterator over the RouletteWheel.
242    ///
243    /// # Examples
244    ///
245    /// ``` ignore
246    /// use roulette_wheel::RouletteWheel;
247    ///
248    /// let rw: RouletteWheel<_> = [(0.1, 10), (0.2, 15), (0.5, 20)].iter().cloned().collect();
249    /// let mut iterator = rw.select_iter();
250    ///
251    /// assert_eq!(iterator.next(), Some((0.5, &20)));
252    /// assert_eq!(iterator.next(), Some((0.1, &10)));
253    /// assert_eq!(iterator.next(), Some((0.2, &15)));
254    /// assert_eq!(iterator.next(), None);
255    /// ```
256    pub fn select_iter(&self) -> SelectIter<ThreadRng, T> {
257        SelectIter::<ThreadRng, _>::new(&self)
258    }
259}
260
261/// Immutable RouletteWheel iterator
262///
263/// This struct is created by the [`select_iter`].
264///
265/// [`iter`]: struct.RouletteWheel.html#method.select_iter
266pub struct SelectIter<'a, R: Rng, T: 'a> {
267    distribution_range: Range<f32>,
268    rng: R,
269    total_fitness: f32,
270    fitnesses_ids: Vec<(usize, f32)>,
271    roulette_wheel: &'a RouletteWheel<T>
272}
273
274impl<'a, R: Rng, T> SelectIter<'a, R, T> {
275    pub fn new(roulette_wheel: &'a RouletteWheel<T>) -> SelectIter<'a, ThreadRng, T> {
276        SelectIter::from_rng(roulette_wheel, thread_rng())
277    }
278
279    pub fn from_rng(roulette_wheel: &'a RouletteWheel<T>, rng: R) -> SelectIter<'a, R, T> {
280        SelectIter {
281            distribution_range: Range::new(0.0, 1.0),
282            rng: rng,
283            total_fitness: roulette_wheel.total_fitness,
284            fitnesses_ids: roulette_wheel.fitnesses.iter().cloned().enumerate().collect(),
285            roulette_wheel: roulette_wheel
286        }
287    }
288}
289
290impl<'a, R: Rng, T: 'a> Iterator for SelectIter<'a, R, T> {
291    type Item = (f32, &'a T);
292
293    fn size_hint(&self) -> (usize, Option<usize>) {
294        (self.fitnesses_ids.len(), Some(self.fitnesses_ids.len()))
295    }
296
297    fn next(&mut self) -> Option<Self::Item> {
298        if !self.fitnesses_ids.is_empty() {
299            let sample = self.distribution_range.ind_sample(&mut self.rng);
300            let mut selection = sample * self.total_fitness;
301            let index = self.fitnesses_ids.iter().position(|&(_, fit)| {
302                            selection -= fit;
303                            selection <= 0.0
304                        }).unwrap();
305            let (index, fitness) = self.fitnesses_ids.swap_remove(index);
306            self.total_fitness -= fitness;
307            Some((fitness, &self.roulette_wheel.population[index]))
308        }
309        else { None }
310    }
311}
312
313impl<T> IntoIterator for RouletteWheel<T> {
314    type Item = (f32, T);
315    type IntoIter = IntoSelectIter<ThreadRng, T>;
316
317    fn into_iter(self) -> IntoSelectIter<ThreadRng, T> {
318        IntoSelectIter::<ThreadRng, _>::new(self)
319    }
320}
321
322/// An iterator that moves out of a RouletteWheel.
323///
324/// This `struct` is created by the `into_iter` method on [`RouletteWheel`][`RouletteWheel`] (provided
325/// by the [`IntoIterator`] trait).
326///
327/// [`RouletteWheel`]: struct.RouletteWheel.html
328/// [`IntoIterator`]: https://doc.rust-lang.org/std/iter/trait.IntoIterator.html
329pub struct IntoSelectIter<R: Rng, T> {
330    distribution_range: Range<f32>,
331    rng: R,
332    total_fitness: f32,
333    fitnesses: Vec<f32>,
334    population: Vec<T>
335}
336
337impl<R: Rng, T> IntoSelectIter<R, T> {
338    pub fn new(roulette_wheel: RouletteWheel<T>) -> IntoSelectIter<ThreadRng, T> {
339        IntoSelectIter::from_rng(roulette_wheel, thread_rng())
340    }
341
342    pub fn from_rng(roulette_wheel: RouletteWheel<T>, rng: R) -> IntoSelectIter<R, T> {
343        IntoSelectIter {
344            distribution_range: Range::new(0.0, 1.0),
345            rng: rng,
346            total_fitness: roulette_wheel.total_fitness,
347            fitnesses: roulette_wheel.fitnesses,
348            population: roulette_wheel.population
349        }
350    }
351}
352
353impl<R: Rng, T> Iterator for IntoSelectIter<R, T> {
354    type Item = (f32, T);
355
356    fn size_hint(&self) -> (usize, Option<usize>) {
357        (self.fitnesses.len(), Some(self.fitnesses.len()))
358    }
359
360    fn next(&mut self) -> Option<Self::Item> {
361        if !self.fitnesses.is_empty() {
362            let sample = self.distribution_range.ind_sample(&mut self.rng);
363            let mut selection = sample * self.total_fitness;
364            let index = self.fitnesses.iter().position(|fit| {
365                            selection -= *fit;
366                            selection <= 0.0
367                        }).unwrap();
368            let fitness = self.fitnesses.swap_remove(index);
369            let individual = self.population.swap_remove(index);
370            self.total_fitness -= fitness;
371            Some((fitness, individual))
372        }
373        else { None }
374    }
375}
376
#[cfg(test)]
mod tests {
    use rand::{SeedableRng, StdRng};
    use super::{RouletteWheel, SelectIter, IntoSelectIter};

    /// Fixed seed so the "random" selection order is reproducible.
    const SEED: [usize; 4] = [4, 2, 42, 4242];

    /// Expected `(fitness, individual)` order for the seeded fixture below.
    const EXPECTED: [(f32, i32); 5] = [(0.5, 19), (0.3, 17), (0.4, 18), (0.2, 16), (0.1, 15)];

    /// Builds the seeded RNG plus the five-individual wheel shared by the
    /// seeded tests: individuals `15..20` weighted `0.1` through `0.5`.
    fn seeded_fixture() -> (StdRng, RouletteWheel<i32>) {
        let rng = StdRng::from_seed(&SEED);
        let weights = [0.1, 0.2, 0.3, 0.4, 0.5];
        let wheel = weights.iter().cloned().zip(15..20).collect();
        (rng, wheel)
    }

    #[test]
    fn test_select_iter_seeded() {
        let (rng, rw) = seeded_fixture();
        let mut iter = SelectIter::from_rng(&rw, rng);

        for &(fitness, individual) in EXPECTED.iter() {
            assert_eq!(iter.next(), Some((fitness, &individual)));
        }
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_into_select_iter_seeded() {
        let (rng, rw) = seeded_fixture();
        let mut iter = IntoSelectIter::from_rng(rw, rng);

        for &(fitness, individual) in EXPECTED.iter() {
            assert_eq!(iter.next(), Some((fitness, individual)));
        }
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn test_len() {
        let mut rw = RouletteWheel::<u8>::new();
        assert_eq!(rw.len(), 0);

        for _ in 0..4 {
            rw.push(0.1, 1);
        }
        assert_eq!(rw.len(), 4);
    }
}