magnesia/optimize/differential_evolution.rs

1use std::{
2    array::from_fn,
3    ops::{Add, Mul, Sub},
4};
5
6/// Data structure implementing the differential evolution algorithm.
7///
8/// # Example
9/// ```
10/// use magnesia::optimize::test_functions::ackley;
11/// use magnesia::optimize::differential_evolution;
12///
13/// let mut rng = rand::thread_rng();
14/// let loc = differential_evolution(
15///     &[(-5.0, 5.0); 2],
16///     20,
17///     0.9,
18///     0.8,
19///     500,
20///     &mut rng,
21///     |lhs, rhs| ackley(lhs[0], lhs[1]) < ackley(rhs[0], rhs[1]),
22/// );
23/// assert!(loc[0].abs() < 0.1);
24/// assert!(loc[1].abs() < 0.1);
25/// ```
26pub fn optimize<T, const N: usize>(
27    bounds: &[(T, T); N],
28    population_size: usize,
29    crossover_probability: f32,
30    differential_weight: T,
31    num_iters: usize,
32    rng: &mut impl rand::Rng,
33    mut compare_candidates: impl FnMut(&[T; N], &[T; N]) -> bool,
34) -> [T; N]
35where
36    T: PartialOrd
37        + Clone
38        + Copy
39        + Add<T, Output = T>
40        + Sub<T, Output = T>
41        + Mul<T, Output = T>
42        + rand::distributions::uniform::SampleUniform,
43    f32: Into<T>,
44    i8: Into<T>,
45{
46    assert!(
47        population_size >= 4,
48        "The population must have at least 4 elements"
49    );
50    assert!(
51        (0.0..=1.0).contains(&crossover_probability),
52        "Invalid crossover probability"
53    );
54    assert!(
55        (0.into()..=2.into()).contains(&differential_weight),
56        "Invalid differential weight"
57    );
58
59    // Create initial candidate set
60    let mut population = (0..population_size)
61        .map(|_| from_fn(|i| rng.gen_range(bounds[i].0..bounds[i].1)))
62        .collect::<Box<[[T; N]]>>();
63
64    // Greedily improve solution
65    let index_range = 0..population.len();
66    for i in index_range.clone().cycle().take(num_iters) {
67        // Generate random distinct indices `j`, `k` and `l`
68        let mut j;
69        loop {
70            j = rng.gen_range(index_range.clone());
71            if j != i {
72                break;
73            }
74        }
75        let mut k;
76        loop {
77            k = rng.gen_range(index_range.clone());
78            if ![i, j].contains(&k) {
79                break;
80            }
81        }
82        let mut l;
83        loop {
84            l = rng.gen_range(index_range.clone());
85            if ![i, j, k].contains(&l) {
86                break;
87            }
88        }
89        // Freeze j, k and l
90        let (j, k, l) = (j, k, l);
91
92        // Generate new candidate
93        let r = rng.gen_range(0..N);
94        let x = from_fn(|n| {
95            if n == r || rng.gen::<f32>() < crossover_probability {
96                population[j][n] + (population[k][n] - population[l][n]) * differential_weight
97            } else {
98                population[i][n]
99            }
100        });
101
102        // If the candidate is better than population[i], replace
103        if compare_candidates(&x, &population[i]) {
104            population[i] = x;
105        }
106    }
107
108    // knockout stage
109    let mut population = population.as_mut();
110    while population.len() > 1 {
111        let pop_len = population.len();
112        let half_len = pop_len / 2;
113        for i in 0..half_len {
114            if compare_candidates(&population[half_len + i], &population[i]) {
115                population[i] = population[half_len + i];
116            }
117            if pop_len % 2 == 1 {
118                population[half_len] = population[pop_len - 1];
119            }
120        }
121        population = &mut population[0..((pop_len + 1) / 2)];
122    }
123
124    // Return the winner
125    population[0]
126}
127
#[test]
fn test_optimize_ackley() {
    use crate::optimize::test_functions::ackley;

    // The assertions below expect the optimum at the origin. The optimization
    // is randomized, so it is repeated several times to keep one lucky run from
    // hiding a regression.
    let mut rng = rand::thread_rng();
    for _run in 0..10 {
        let [x, y] = optimize(
            &[(-5.0, 5.0); 2],
            20,
            0.9,
            0.8,
            500,
            &mut rng,
            |challenger, incumbent| {
                ackley(challenger[0], challenger[1]) < ackley(incumbent[0], incumbent[1])
            },
        );
        assert!(x.abs() < 0.1);
        assert!(y.abs() < 0.1);
    }
}
146}