neural/
lib.rs

1//! # Neural
2//!
3//! **Neural** is a library for Genetic Algorithms in Rust.
4//!
5//! ## Concepts
6//!
//! - [**Genetic Algorithm**](https://en.wikipedia.org/wiki/Genetic_algorithm) - An algorithm for solving optimization problems by simulating the process of natural selection, evolution, mutation and crossover (reproduction).
//! - **Gene** - A single value that forms one component of a possible solution to a problem.
9//! - **Chromosome** - A collection of genes that represent a possible solution to a problem.
10//! - **Population** - A collection of chromosomes that represent potential solutions to a problem.
11//! - **Selection** - A trait for selecting a subset of the population.
12//! - **Crossover** - A trait for crossing over two chromosomes.
13//!
14//! ## Features
15//!
16//! - **print** - Allows using the `with_print` method on a `PopulationBuilder` to print the population's best chromosome after each generation and enables colored output.
17//!
18//! ## Getting Started
19//!
20//! ```sh
21//! cargo add neural
22//!
23//! # or add it with the print feature
24//! cargo add neural --features print
25//! ```
26//!
27//! Or add it as a dependency in your `Cargo.toml`:
28//!
29//! ```toml
30//! [dependencies]
31//! neural = "^0.3"
32//!
33//! # or add it with the print feature
34//! [dependencies]
35//! neural = { version = "^0.3", features = ["print"] }
36//! ```
37//!
38//! ## Usage
39//!
40//! ```rust
41//! use neural::{Gene, Population, PopulationBuilder, Result, TournamentSelection, UniformCrossover};
42//! use rand::{rngs::ThreadRng, Rng};
43//! use std::fmt::Display;
44//!
45//! #[derive(Debug, Clone, PartialEq, PartialOrd)]
46//! struct F64(f64);
47//!
48//! impl Eq for F64 {}
49//!
50//! impl Display for F64 {
51//!     fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
52//!         write!(f, "{}", self.0)
53//!     }
54//! }
55//!
56//! impl From<f64> for F64 {
57//!     fn from(value: f64) -> Self {
58//!         Self(value)
59//!     }
60//! }
61//!
62//! impl From<F64> for f64 {
63//!     fn from(val: F64) -> Self {
64//!         val.0
65//!     }
66//! }
67//!
68//! impl Gene for F64 {
69//!     fn generate_gene<R>(rng: &mut R) -> Self
70//!     where
71//!         R: Rng + ?Sized,
72//!     {
73//!         rng.random_range(-1.0..=1.0).into()
74//!     }
75//! }
76//!
77//! fn main() -> Result<()> {
78//!     let mut pop: Population<F64, TournamentSelection, UniformCrossover, _, ThreadRng> =
79//!         PopulationBuilder::new(None, |c| c.value.iter().map(|g: &F64| g.0).sum::<f64>())
80//!             .with_chromo_size(50)
81//!             .with_population_size(100)
82//!             .with_mutation_rate(0.02)
83//!             .with_elitism(true)
84//!             // .with_print(true) // uncomment to print the best chromosome after each generation. requires the print feature
85//!             .build()?;
86//!
87//!     let num_generations = 200;
88//!     match pop.evolve(num_generations) {
89//!         Some(best) => {
90//!             println!("Evolution complete after {num_generations} generations.");
91//!             println!(
92//!                 "Best fitness found: {}, chromosome: {:?}",
93//!                 best.fitness, best.value
94//!             );
95//!         }
96//!         None => println!("Evolution ended with an empty population."),
97//!     }
98//!
99//!     Ok(())
100//! }
101//! ```
102
103#![forbid(unsafe_code)]
104#![warn(clippy::pedantic, missing_debug_implementations)]
105
106mod crossover;
107mod errors;
108mod selection;
109
110pub use crossover::{Crossover, UniformCrossover};
111pub use errors::{NeuralError, Result};
112use rand::{seq::IndexedRandom, Rng};
113pub use selection::{RouletteWheelSelection, Selection, TournamentSelection};
114use std::cmp::Ordering;
115
116#[cfg(feature = "print")]
117use colored::Colorize;
118
119/// Trait for generating a random Gene
120pub trait Gene: Clone {
121    /// Returns a random [`Gene`]
122    ///
123    /// ## Arguments
124    ///
125    /// * `rng` - The random number generator
126    fn generate_gene<R>(rng: &mut R) -> Self
127    where
128        R: Rng + ?Sized;
129}
130
/// Represents a Chromosome with [Genes](Gene)
///
/// Equality (`==`) compares both the genes and the fitness, whereas ordering
/// ([`Ord`] / [`PartialOrd`]) looks at the fitness only.
#[derive(Debug, Clone, PartialEq)]
pub struct Chromosome<G> {
    /// The genes that make up this chromosome
    pub value: Vec<G>,
    /// The fitness score currently assigned to this chromosome
    pub fitness: f64,
}

impl<G> Chromosome<G> {
    /// Creates a new [Chromosome] from the given genes with a fitness of `0.0`
    ///
    /// ## Arguments
    ///
    /// * `value` - The genes of the chromosome
    #[must_use]
    pub fn new(value: Vec<G>) -> Self {
        Self {
            value,
            fitness: 0.0,
        }
    }
}

impl<G> Eq for Chromosome<G> where G: PartialEq {}

impl<G> PartialOrd for Chromosome<G>
where
    G: PartialEq,
{
    /// Delegates to [`Ord::cmp`], so the result is always `Some`
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
        Some(self.cmp(other))
    }
}

impl<G> Ord for Chromosome<G>
where
    G: PartialEq,
{
    /// Orders chromosomes by fitness.
    ///
    /// A `NaN` fitness is ordered below every number and equal to any other `NaN`.
    /// Mapping every `NaN` comparison to `Equal` would break transitivity
    /// (`1.0 == NaN == 2.0` while `1.0 < 2.0`), which sorting relies on.
    /// Comparisons between two non-`NaN` values behave exactly like `f64::partial_cmp`.
    fn cmp(&self, other: &Self) -> Ordering {
        match (self.fitness.is_nan(), other.fitness.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Less,
            (false, true) => Ordering::Greater,
            // Neither side is NaN, so `partial_cmp` cannot actually return `None` here.
            (false, false) => self
                .fitness
                .partial_cmp(&other.fitness)
                .unwrap_or(Ordering::Equal),
        }
    }
}
170
171/// Represents a Population of [Chromosomes](Chromosome)
172#[derive(Debug)]
173pub struct Population<G, S, C, F, R>
174where
175    R: Rng + ?Sized,
176{
177    pub chromo_size: usize,
178    pub pop_size: usize,
179    pub mut_rate: f64,
180    pub population: Vec<Chromosome<G>>,
181    pub eval_fn: F,
182    pub selection: S,
183    pub crossover: C,
184    pub elitism: bool,
185    pub rng: Box<R>,
186
187    #[cfg(feature = "print")]
188    pub print: bool,
189}
190
191impl<G, S, C, F, R> Population<G, S, C, F, R>
192where
193    G: Gene,
194    S: Selection<G>,
195    C: Crossover<G>,
196    F: FnMut(&Chromosome<G>) -> f64,
197    R: Rng + ?Sized,
198{
199    /// Returns the best [Chromosome], the chromosome with the highest fitness
200    #[must_use]
201    pub fn best(&self) -> Option<&Chromosome<G>> {
202        let mut best_fitness = 0.0;
203        let mut best_match = None;
204
205        for (i, c) in self.population.iter().enumerate() {
206            if c.fitness > best_fitness {
207                best_fitness = c.fitness;
208                best_match = Some(i);
209            }
210        }
211
212        match best_match {
213            Some(i) => Some(&self.population[i]),
214            None => None,
215        }
216    }
217
218    /// Evaluates the chromosomes in the population
219    pub fn evaluate(&mut self) {
220        self.population.iter_mut().for_each(|c| {
221            c.fitness = (self.eval_fn)(c);
222        });
223    }
224
225    /// Evolves the [Population] and returns the best chromosome
226    ///
227    /// ## Arguments
228    ///
229    /// * `generations` - The number of generations to evolve the [Population] for
230    /// * `rng` - The random number generator
231    ///
232    /// ## Returns
233    ///
234    /// The best [Chromosome] in the [Population]
235    #[allow(clippy::used_underscore_binding)]
236    pub fn evolve(&mut self, generations: u32) -> Option<Chromosome<G>> {
237        if self.population.is_empty() {
238            return None;
239        }
240
241        let elitism_offset = usize::from(self.elitism);
242        for _gen in 0..generations {
243            if self.population.is_empty() {
244                #[cfg(feature = "print")]
245                if self.print {
246                    println!("Population collapsed at generation {_gen}");
247                }
248                return None;
249            }
250
251            let mut next_gen = Vec::with_capacity(self.pop_size);
252
253            if self.elitism {
254                if let Some(best) = self.best() {
255                    next_gen.push(best.clone());
256                }
257            }
258
259            let fill_count = self.pop_size - next_gen.len();
260            for _ in elitism_offset..fill_count {
261                let parents = self.selection.select(&self.population, 2, &mut self.rng);
262                if parents.len() < 2 {
263                    if let Some(ind) = self.population.choose(&mut self.rng) {
264                        next_gen.push(ind.clone());
265                    } else {
266                        continue;
267                    }
268                    continue;
269                }
270
271                if let Some(offspring) =
272                    self.crossover
273                        .crossover(parents[0], parents[1], &mut self.rng)
274                {
275                    next_gen.push(offspring);
276                }
277            }
278
279            self.population = next_gen;
280            self.mutate();
281            self.evaluate();
282
283            #[cfg(feature = "print")]
284            if self.print {
285                if let Some(best) = self.best() {
286                    println!(
287                        "Generation: {}: Best fitness = {}",
288                        _gen.to_string().cyan().bold(),
289                        best.fitness.to_string().cyan().bold()
290                    );
291                }
292            }
293        }
294
295        self.best().cloned()
296    }
297
298    /// Mutates the chromosomes in the population
299    ///
300    /// ## Arguments
301    ///
302    /// * `rng` - The random number generator
303    pub fn mutate(&mut self) {
304        self.population.iter_mut().for_each(|c| {
305            for g in &mut c.value {
306                if self.rng.random_bool(self.mut_rate) {
307                    *g = G::generate_gene(&mut self.rng);
308                }
309            }
310        });
311    }
312
313    /// Returns the worst [Chromosome], the chromosome with the lowest fitness
314    #[must_use]
315    pub fn worst(&self) -> Option<&Chromosome<G>> {
316        if self.population.is_empty() {
317            return None;
318        }
319
320        match self.worst_index() {
321            Some(i) => Some(&self.population[i]),
322            None => None,
323        }
324    }
325
326    /// Returns the index of the worst [Chromosome], the chromosome with the lowest fitness
327    #[must_use]
328    pub fn worst_index(&self) -> Option<usize> {
329        if self.population.is_empty() {
330            return None;
331        }
332
333        let mut best_fitness = self.population[0].fitness;
334        let mut best_match = None;
335
336        for (i, c) in self.population.iter().enumerate().skip(1) {
337            if c.fitness < best_fitness {
338                best_fitness = c.fitness;
339                best_match = Some(i);
340            }
341        }
342
343        best_match
344    }
345}
346
/// Builder for a [Population]
///
/// Collects the settings of a [Population] step by step; `build` validates them
/// and produces the population.
#[derive(Debug)]
pub struct PopulationBuilder<G, S, C, F, R> {
    // Number of genes per generated chromosome.
    chromo_size: usize,
    // Number of chromosomes to generate / keep per generation.
    pop_size: usize,
    // Per-gene mutation probability; `build` requires it to lie in `0.0..=1.0`.
    mut_rate: f64,
    // Optional pre-made population; when `None`, `build` generates a random one.
    population: Option<Vec<Chromosome<G>>>,
    // Fitness function handed over to the population.
    eval_fn: F,
    // Parent selection strategy.
    selection: S,
    // Crossover strategy.
    crossover: C,
    // Whether the best chromosome is carried over between generations.
    elitism: bool,
    // Random number generator; boxed when the population is built.
    rng: R,

    // Whether the built population prints progress while evolving.
    #[cfg(feature = "print")]
    print: bool,
}
363
364impl<G, S, C, F, R> PopulationBuilder<G, S, C, F, R>
365where
366    G: Gene,
367    S: Selection<G>,
368    C: Crossover<G>,
369    F: FnMut(&Chromosome<G>) -> f64,
370    R: Rng + Default,
371{
372    /// Creates a new [`PopulationBuilder`]
373    ///
374    /// ## Arguments
375    ///
376    /// * `population` - The population to use
377    /// * `eval_fn` - The evaluation function
378    ///
379    /// ## Returns
380    ///
381    /// A new [`PopulationBuilder`]
382    #[must_use]
383    pub fn new(population: Option<Vec<Chromosome<G>>>, eval_fn: F) -> Self {
384        Self {
385            chromo_size: 10,
386            pop_size: 10,
387            mut_rate: 0.015,
388            population,
389            eval_fn,
390            selection: S::default(),
391            crossover: C::default(),
392            elitism: false,
393            rng: R::default(),
394
395            #[cfg(feature = "print")]
396            print: false,
397        }
398    }
399
400    /// Sets the chromosome size, how many genes
401    ///
402    /// ## Arguments
403    ///
404    /// * `chromo_size` - The chromosome size
405    #[must_use]
406    pub fn with_chromo_size(mut self, chromo_size: usize) -> Self {
407        self.chromo_size = chromo_size;
408        self
409    }
410
411    /// Sets the population size
412    ///
413    /// ## Arguments
414    ///
415    /// * `pop_size` - The population size
416    #[must_use]
417    pub fn with_population_size(mut self, pop_size: usize) -> Self {
418        self.pop_size = pop_size;
419        self
420    }
421
422    /// Sets the mutation rate
423    ///
424    /// ## Arguments
425    ///
426    /// * `mutation_rate` - The mutation rate
427    #[must_use]
428    pub fn with_mutation_rate(mut self, mutation_rate: f64) -> Self {
429        self.mut_rate = mutation_rate;
430        self
431    }
432
433    /// Sets the elitism flag
434    ///
435    /// ## Arguments
436    ///
437    /// * `elitism` - Whether or not to use elitism (keep the best chromosome)
438    #[must_use]
439    pub fn with_elitism(mut self, elitism: bool) -> Self {
440        self.elitism = elitism;
441        self
442    }
443
444    /// Sets the print flag
445    ///
446    /// ## Arguments
447    ///
448    /// * `print` - The print flag
449    #[must_use]
450    #[cfg(feature = "print")]
451    pub fn with_print(mut self, print: bool) -> Self {
452        self.print = print;
453        self
454    }
455
456    /// Sets the random number generator
457    ///
458    /// ## Arguments
459    ///
460    /// * `rng` - The random number generator to use
461    #[must_use]
462    pub fn with_rng(mut self, rng: R) -> Self {
463        self.rng = rng;
464        self
465    }
466
467    /// Sets the selection method
468    ///
469    /// ## Arguments
470    ///
471    /// * `selection` - The selection method
472    #[must_use]
473    pub fn with_selection(mut self, selection: S) -> Self {
474        self.selection = selection;
475        self
476    }
477
478    /// Sets the crossover method
479    ///
480    /// ## Arguments
481    ///
482    /// * `crossover` - The crossover method
483    #[must_use]
484    pub fn with_crossover(mut self, crossover: C) -> Self {
485        self.crossover = crossover;
486        self
487    }
488
489    /// Builds the [Population]
490    ///
491    /// ## Errors
492    ///
493    /// - [`NeuralError::NoPopSize`]: If the population size is not set
494    /// - [`NeuralError::NoChromoSize`]: If the chromosome size is not set
495    /// - [`NeuralError::NoMutationRate`]: If the mutation rate is not set
496    pub fn build(self) -> Result<Population<G, S, C, F, R>> {
497        let mut n = Population {
498            chromo_size: self.chromo_size,
499            pop_size: self.pop_size,
500            mut_rate: self.mut_rate,
501            population: Vec::new(),
502            eval_fn: self.eval_fn,
503            selection: self.selection,
504            crossover: self.crossover,
505            elitism: self.elitism,
506            rng: Box::new(self.rng),
507
508            #[cfg(feature = "print")]
509            print: self.print,
510        };
511
512        if let Some(pop) = self.population {
513            n.population = pop;
514        } else {
515            let mut pop = Vec::with_capacity(n.pop_size);
516            for _ in 0..n.pop_size {
517                pop.push(Chromosome {
518                    value: (0..n.chromo_size)
519                        .map(|_| G::generate_gene(&mut n.rng))
520                        .collect(),
521                    fitness: 0.0,
522                });
523            }
524            n.population = pop;
525        }
526
527        if n.chromo_size == 0 {
528            return Err(NeuralError::NoChromoSize);
529        }
530
531        if n.pop_size == 0 {
532            return Err(NeuralError::NoPopSize);
533        }
534
535        if !(0.0..=1.0).contains(&n.mut_rate) {
536            return Err(NeuralError::NoMutationRate);
537        }
538
539        n.evaluate();
540
541        Ok(n)
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::ThreadRng;

    // Generates a smoke test that builds a random population of `f64`-backed genes
    // and evolves it with the given selection and crossover strategies.
    macro_rules! generate_selection_test {
        ($name:ident, $selection:ty, $crossover:ty) => {
            #[test]
            fn $name() -> Result<()> {
                #[derive(Debug, Clone, PartialEq, PartialOrd)]
                struct F64(f64);

                impl Eq for F64 {}
                impl From<f64> for F64 {
                    fn from(value: f64) -> Self {
                        Self(value)
                    }
                }

                impl Gene for F64 {
                    fn generate_gene<R>(rng: &mut R) -> Self
                    where
                        R: Rng + ?Sized,
                    {
                        rng.random_range(-1.0..=1.0).into()
                    }
                }

                let generations = 200;
                let mut population: Population<F64, $selection, $crossover, _, ThreadRng> =
                    PopulationBuilder::new(None, |c| {
                        c.value.iter().map(|g: &F64| g.0).sum::<f64>()
                    })
                    .with_chromo_size(50)
                    .with_population_size(100)
                    .with_mutation_rate(0.02)
                    .build()?;

                // Only checks that evolving runs to completion; the result is not inspected.
                population.evolve(generations);
                Ok(())
            }
        };
    }

    generate_selection_test!(
        test_population_roulette_wheel_uniform,
        RouletteWheelSelection,
        UniformCrossover
    );
    generate_selection_test!(
        test_population_tournament_uniform,
        TournamentSelection,
        UniformCrossover
    );

    // Building with a population size of zero must fail with `NoPopSize`;
    // unwrapping that error is what triggers the expected panic.
    #[test]
    #[should_panic(expected = "NoPopSize")]
    #[allow(non_local_definitions)]
    fn test_no_pop_size() {
        impl Gene for i32 {
            fn generate_gene<R>(rng: &mut R) -> Self
            where
                R: Rng + ?Sized,
            {
                rng.random()
            }
        }

        let builder =
            PopulationBuilder::<i32, TournamentSelection, UniformCrossover, _, ThreadRng>::new(
                None,
                |c| f64::from(c.value.iter().sum::<i32>()),
            )
            .with_population_size(0);

        builder.build().unwrap();
    }
}