1use crate::error::OptimError;
8use crate::types::{IterationRecord, OptimResult, OptimStatus};
9use numra_core::Scalar;
10use rand::rngs::SmallRng;
11use rand::{Rng, SeedableRng};
12
/// Tuning parameters for [`de_minimize`].
///
/// Usually built with struct-update syntax over [`Default`], e.g.
/// `DEOptions { max_generations: 500, ..Default::default() }`.
///
/// The struct itself places no bound on `S`; the impls that need numeric
/// behaviour (such as `Default`) require `S: Scalar` themselves.
#[derive(Clone, Debug)]
pub struct DEOptions<S> {
    /// Number of population members. `None` lets [`de_minimize`] use
    /// `10 * bounds.len()`.
    pub pop_size: Option<usize>,
    /// Maximum number of generations to run before giving up.
    pub max_generations: usize,
    /// Differential weight `F` that scales the difference vector
    /// `x_b - x_c` when building a mutant.
    pub differential_weight: S,
    /// Crossover rate: probability that a coordinate of the trial vector is
    /// taken from the mutant rather than from the current member.
    pub crossover_rate: S,
    /// Convergence tolerance on the population's fitness spread
    /// (`max(fitness) - min(fitness)`).
    pub tol: S,
    /// Seed for the internal pseudo-random generator; equal seeds reproduce
    /// equal runs.
    pub seed: u64,
    /// When `true`, the best objective of every generation is printed to
    /// stderr.
    pub verbose: bool,
}
31
32impl<S: Scalar> Default for DEOptions<S> {
33 fn default() -> Self {
34 Self {
35 pop_size: None,
36 max_generations: 1000,
37 differential_weight: S::from_f64(0.8),
38 crossover_rate: S::from_f64(0.9),
39 tol: S::from_f64(1e-12),
40 seed: 42,
41 verbose: false,
42 }
43 }
44}
45
46pub fn de_minimize<S: Scalar, F>(
62 f: F,
63 bounds: &[(S, S)],
64 opts: &DEOptions<S>,
65) -> Result<OptimResult<S>, OptimError>
66where
67 F: Fn(&[S]) -> S,
68{
69 let start = std::time::Instant::now();
70 let n = bounds.len();
71 let np = opts.pop_size.unwrap_or(10 * n);
72 let cr = opts.crossover_rate;
73 let f_weight = opts.differential_weight;
74
75 let mut rng = SmallRng::seed_from_u64(opts.seed);
76
77 let mut pop: Vec<Vec<S>> = (0..np)
79 .map(|_| {
80 bounds
81 .iter()
82 .map(|&(lo, hi)| lo + S::from_f64(rng.gen::<f64>()) * (hi - lo))
83 .collect()
84 })
85 .collect();
86
87 let mut fitness: Vec<S> = pop.iter().map(|xi| f(xi)).collect();
89 let mut n_feval: usize = np;
90
91 let mut best_idx: usize = 0;
93 for i in 1..np {
94 if fitness[i] < fitness[best_idx] {
95 best_idx = i;
96 }
97 }
98 let mut best_x = pop[best_idx].clone();
99 let mut best_f = fitness[best_idx];
100
101 let mut history: Vec<IterationRecord<S>> = Vec::new();
102 let mut converged = false;
103 let mut gen = 0usize;
104
105 for g in 0..opts.max_generations {
107 gen = g + 1;
108
109 for i in 0..np {
110 let (a, b, c) = pick_three(&mut rng, np, i);
112
113 let mut trial = vec![S::ZERO; n];
115 let j_rand = rng.gen_range(0..n);
116
117 for j in 0..n {
118 let v_j = pop[a][j] + f_weight * (pop[b][j] - pop[c][j]);
119 let v_j = v_j.clamp(bounds[j].0, bounds[j].1);
121
122 if S::from_f64(rng.gen::<f64>()) < cr || j == j_rand {
124 trial[j] = v_j;
125 } else {
126 trial[j] = pop[i][j];
127 }
128 }
129
130 let trial_f = f(&trial);
132 n_feval += 1;
133
134 if trial_f <= fitness[i] {
135 pop[i] = trial;
136 fitness[i] = trial_f;
137
138 if trial_f < best_f {
139 best_f = trial_f;
140 best_x = pop[i].clone();
141 }
142 }
143 }
144
145 history.push(IterationRecord {
147 iteration: gen,
148 objective: best_f,
149 gradient_norm: S::ZERO,
150 step_size: S::ZERO,
151 constraint_violation: S::ZERO,
152 });
153
154 if opts.verbose {
155 eprintln!("DE gen {}: best_f = {:.6e}", gen, best_f.to_f64());
156 }
157
158 let f_max = fitness.iter().copied().fold(S::NEG_INFINITY, S::max);
160 let f_min = fitness.iter().copied().fold(S::INFINITY, S::min);
161 if (f_max - f_min) < opts.tol {
162 converged = true;
163 break;
164 }
165 }
166
167 let (status, message) = if converged {
168 (
169 OptimStatus::GradientConverged,
170 format!(
171 "Converged: fitness spread < tol ({:.2e}) at generation {}",
172 opts.tol.to_f64(),
173 gen
174 ),
175 )
176 } else {
177 (
178 OptimStatus::MaxIterations,
179 format!(
180 "Reached maximum generations ({}), best f = {:.6e}",
181 opts.max_generations,
182 best_f.to_f64()
183 ),
184 )
185 };
186
187 let result = OptimResult {
188 x: best_x,
189 f: best_f,
190 grad: Vec::new(),
191 iterations: gen,
192 n_feval,
193 n_geval: 0,
194 converged,
195 message,
196 status,
197 history,
198 lambda_eq: Vec::new(),
199 lambda_ineq: Vec::new(),
200 active_bounds: Vec::new(),
201 constraint_violation: S::ZERO,
202 wall_time_secs: 0.0,
203 pareto: None,
204 sensitivity: None,
205 }
206 .with_wall_time(start);
207
208 Ok(result)
209}
210
211fn pick_three(rng: &mut SmallRng, n: usize, exclude: usize) -> (usize, usize, usize) {
213 let mut a = rng.gen_range(0..n);
214 while a == exclude {
215 a = rng.gen_range(0..n);
216 }
217 let mut b = rng.gen_range(0..n);
218 while b == exclude || b == a {
219 b = rng.gen_range(0..n);
220 }
221 let mut c = rng.gen_range(0..n);
222 while c == exclude || c == a || c == b {
223 c = rng.gen_range(0..n);
224 }
225 (a, b, c)
226}
227
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs [`de_minimize`] with default options except for the generation
    /// cap, panicking if the optimizer reports an error.
    fn run_de<F>(objective: F, bounds: &[(f64, f64)], max_generations: usize) -> OptimResult<f64>
    where
        F: Fn(&[f64]) -> f64,
    {
        let opts = DEOptions {
            max_generations,
            ..Default::default()
        };
        de_minimize(objective, bounds, &opts).unwrap()
    }

    /// 3-D quadratic bowl: must converge and land near the origin.
    #[test]
    fn test_de_sphere() {
        let sphere = |x: &[f64]| x.iter().map(|v| v * v).sum::<f64>();
        let res = run_de(sphere, &[(-5.0, 5.0); 3], 500);

        assert!(res.converged, "did not converge: {}", res.message);
        res.x
            .iter()
            .for_each(|&xi| assert!(xi.abs() < 0.1, "xi={}", xi));
        assert!(res.f < 0.01, "f={}", res.f);
    }

    /// Multimodal 2-D Rastrigin function: the global minimum (0) should be
    /// approached despite many local minima.
    #[test]
    fn test_de_rastrigin() {
        use std::f64::consts::PI;

        let rastrigin = |x: &[f64]| {
            let dim = x.len() as f64;
            let ripple: f64 = x
                .iter()
                .map(|xi| xi * xi - 10.0 * (2.0 * PI * xi).cos())
                .sum();
            10.0 * dim + ripple
        };
        let res = run_de(rastrigin, &[(-5.12, 5.12); 2], 1000);

        assert!(res.f < 1.0, "f={}", res.f);
    }

    /// Curved Rosenbrock valley with its minimum at (1, 1).
    #[test]
    fn test_de_rosenbrock() {
        let rosenbrock = |x: &[f64]| {
            let (a, b) = (x[0], x[1]);
            (1.0 - a).powi(2) + 100.0 * (b - a * a).powi(2)
        };
        let res = run_de(rosenbrock, &[(-5.0, 10.0); 2], 2000);

        assert!(res.f < 0.01, "f={}", res.f);
        assert!((res.x[0] - 1.0).abs() < 0.1, "x0={}", res.x[0]);
    }

    /// Two runs with the same seed must produce bit-identical results.
    #[test]
    fn test_de_deterministic() {
        let bounds = [(-5.0, 5.0); 2];
        let objective = |x: &[f64]| x[0] * x[0] + x[1] * x[1];
        let defaults = DEOptions::default();

        let first = de_minimize(objective, &bounds, &defaults).unwrap();
        let second = de_minimize(objective, &bounds, &defaults).unwrap();

        assert_eq!(first.x, second.x);
        assert_eq!(first.f, second.f);
    }

    /// One-dimensional parabola with its minimum at x = 3.
    #[test]
    fn test_de_1d() {
        let parabola = |x: &[f64]| (x[0] - 3.0).powi(2);
        let res = de_minimize(parabola, &[(0.0, 10.0)], &DEOptions::default()).unwrap();

        assert!((res.x[0] - 3.0).abs() < 0.1, "x={}", res.x[0]);
    }
}