magnesia/optimize/
differential_evolution.rs1use std::{
2 array::from_fn,
3 ops::{Add, Mul, Sub},
4};
5
6pub fn optimize<T, const N: usize>(
27 bounds: &[(T, T); N],
28 population_size: usize,
29 crossover_probability: f32,
30 differential_weight: T,
31 num_iters: usize,
32 rng: &mut impl rand::Rng,
33 mut compare_candidates: impl FnMut(&[T; N], &[T; N]) -> bool,
34) -> [T; N]
35where
36 T: PartialOrd
37 + Clone
38 + Copy
39 + Add<T, Output = T>
40 + Sub<T, Output = T>
41 + Mul<T, Output = T>
42 + rand::distributions::uniform::SampleUniform,
43 f32: Into<T>,
44 i8: Into<T>,
45{
46 assert!(
47 population_size >= 4,
48 "The population must have at least 4 elements"
49 );
50 assert!(
51 (0.0..=1.0).contains(&crossover_probability),
52 "Invalid crossover probability"
53 );
54 assert!(
55 (0.into()..=2.into()).contains(&differential_weight),
56 "Invalid differential weight"
57 );
58
59 let mut population = (0..population_size)
61 .map(|_| from_fn(|i| rng.gen_range(bounds[i].0..bounds[i].1)))
62 .collect::<Box<[[T; N]]>>();
63
64 let index_range = 0..population.len();
66 for i in index_range.clone().cycle().take(num_iters) {
67 let mut j;
69 loop {
70 j = rng.gen_range(index_range.clone());
71 if j != i {
72 break;
73 }
74 }
75 let mut k;
76 loop {
77 k = rng.gen_range(index_range.clone());
78 if ![i, j].contains(&k) {
79 break;
80 }
81 }
82 let mut l;
83 loop {
84 l = rng.gen_range(index_range.clone());
85 if ![i, j, k].contains(&l) {
86 break;
87 }
88 }
89 let (j, k, l) = (j, k, l);
91
92 let r = rng.gen_range(0..N);
94 let x = from_fn(|n| {
95 if n == r || rng.gen::<f32>() < crossover_probability {
96 population[j][n] + (population[k][n] - population[l][n]) * differential_weight
97 } else {
98 population[i][n]
99 }
100 });
101
102 if compare_candidates(&x, &population[i]) {
104 population[i] = x;
105 }
106 }
107
108 let mut population = population.as_mut();
110 while population.len() > 1 {
111 let pop_len = population.len();
112 let half_len = pop_len / 2;
113 for i in 0..half_len {
114 if compare_candidates(&population[half_len + i], &population[i]) {
115 population[i] = population[half_len + i];
116 }
117 if pop_len % 2 == 1 {
118 population[half_len] = population[pop_len - 1];
119 }
120 }
121 population = &mut population[0..((pop_len + 1) / 2)];
122 }
123
124 population[0]
126}
127
/// Checks that differential evolution finds the global minimum of the 2-D Ackley
/// function (at the origin) over several independent runs.
#[test]
fn test_optimize_ackley() {
    use crate::optimize::test_functions::ackley;
    use rand::{rngs::StdRng, SeedableRng};

    // Fixed seeds make every run reproducible: a failure can be replayed exactly
    // instead of appearing as an intermittent flake from `thread_rng`.
    for seed in 0..10u64 {
        let mut rng = StdRng::seed_from_u64(seed);
        let loc = optimize(
            &[(-5.0, 5.0); 2],
            20,
            0.9,
            0.8,
            500,
            &mut rng,
            |lhs, rhs| ackley(lhs[0], lhs[1]) < ackley(rhs[0], rhs[1]),
        );
        assert!(
            loc[0].abs() < 0.1 && loc[1].abs() < 0.1,
            "seed {seed}: expected a point near the origin, got {loc:?}"
        );
    }
}