/// Result of a Nelder–Mead minimization.
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMead {
    /// Best parameter vector found (the lowest vertex of the final simplex,
    /// or the starting point when the search returns early).
    pub par: Vec<f64>,
    /// Objective value associated with `par`.
    pub value: f64,
    /// Number of objective evaluations reported by the optimizer.
    pub fncount: usize,
    /// Status code: `0` success, `1` evaluation budget exhausted,
    /// `2` objective not finite at the starting point, `10` a shrink step
    /// failed to reduce the simplex size.
    pub fail: i32,
}
26
/// Penalty substituted for any non-finite objective value during the search,
/// so such points always rank as (near) worst vertices.
const BIG: f64 = 1e35;
29
30pub fn nelder_mead<F>(x0: &[f64], f: F) -> NelderMead
34where
35 F: FnMut(&[f64]) -> f64,
36{
37 nelder_mead_with(
38 x0,
39 f,
40 f64::NEG_INFINITY,
41 f64::EPSILON.sqrt(),
42 500,
43 1.0,
44 0.5,
45 2.0,
46 )
47}
48
49#[allow(clippy::too_many_arguments, clippy::needless_range_loop)]
55pub fn nelder_mead_with<F>(
56 x0: &[f64],
57 mut f: F,
58 abstol: f64,
59 reltol: f64,
60 maxit: usize,
61 alpha: f64,
62 beta: f64,
63 gamma: f64,
64) -> NelderMead
65where
66 F: FnMut(&[f64]) -> f64,
67{
68 let n = x0.len();
69 let mut bvec = x0.to_vec();
70
71 if maxit == 0 {
72 let value = f(&bvec);
73 return NelderMead {
74 par: bvec,
75 value,
76 fncount: 0,
77 fail: 0,
78 };
79 }
80 if n == 0 {
81 let value = f(&bvec);
82 return NelderMead {
83 par: bvec,
84 value,
85 fncount: 1,
86 fail: 0,
87 };
88 }
89
90 let n1 = n + 1; let cidx = n + 1; let vrow = n; let mut p = vec![vec![0.0_f64; n + 2]; n + 1];
97
98 let f0 = f(&bvec);
99 if !f0.is_finite() {
100 return NelderMead {
101 par: bvec,
102 value: f0,
103 fncount: 1,
104 fail: 2,
105 };
106 }
107 let mut funcount = 1usize;
108 let convtol = reltol * (f0.abs() + reltol);
109
110 p[vrow][0] = f0;
111 for i in 0..n {
112 p[i][0] = bvec[i];
113 }
114
115 let mut l = 1usize; let mut size = 0.0;
117
118 let mut step = 0.0;
119 for i in 0..n {
120 let s = 0.1 * bvec[i].abs();
121 if s > step {
122 step = s;
123 }
124 }
125 if step == 0.0 {
126 step = 0.1;
127 }
128
129 for j in 2..=n1 {
131 for i in 0..n {
132 p[i][j - 1] = bvec[i];
133 }
134 let mut trystep = step;
135 while p[j - 2][j - 1] == bvec[j - 2] {
136 p[j - 2][j - 1] = bvec[j - 2] + trystep;
137 trystep *= 10.0;
138 }
139 size += trystep;
140 }
141 let mut oldsize = size;
142 let mut calcvert = true;
143 let mut fail = 0i32;
144
145 loop {
146 if calcvert {
147 for j in 0..n1 {
148 if j + 1 != l {
149 for i in 0..n {
150 bvec[i] = p[i][j];
151 }
152 let mut fj = f(&bvec);
153 if !fj.is_finite() {
154 fj = BIG;
155 }
156 funcount += 1;
157 p[vrow][j] = fj;
158 }
159 }
160 calcvert = false;
161 }
162
163 let mut vl = p[vrow][l - 1];
164 let mut vh = vl;
165 let mut h = l;
166 for j in 1..=n1 {
167 if j != l {
168 let fj = p[vrow][j - 1];
169 if fj < vl {
170 l = j;
171 vl = fj;
172 }
173 if fj > vh {
174 h = j;
175 vh = fj;
176 }
177 }
178 }
179
180 if vh <= vl + convtol || vl <= abstol {
181 break;
182 }
183
184 for i in 0..n {
186 let mut temp = -p[i][h - 1];
187 for j in 0..n1 {
188 temp += p[i][j];
189 }
190 p[i][cidx] = temp / n as f64;
191 }
192 for i in 0..n {
194 bvec[i] = (1.0 + alpha) * p[i][cidx] - alpha * p[i][h - 1];
195 }
196 let mut vr = f(&bvec);
197 if !vr.is_finite() {
198 vr = BIG;
199 }
200 funcount += 1;
201
202 if vr < vl {
203 p[vrow][cidx] = vr;
205 for i in 0..n {
206 let fe = gamma * bvec[i] + (1.0 - gamma) * p[i][cidx];
207 p[i][cidx] = bvec[i];
208 bvec[i] = fe;
209 }
210 let mut fe = f(&bvec);
211 if !fe.is_finite() {
212 fe = BIG;
213 }
214 funcount += 1;
215 if fe < vr {
216 for i in 0..n {
217 p[i][h - 1] = bvec[i];
218 }
219 p[vrow][h - 1] = fe;
220 } else {
221 for i in 0..n {
222 p[i][h - 1] = p[i][cidx];
223 }
224 p[vrow][h - 1] = vr;
225 }
226 } else {
227 if vr < vh {
229 for i in 0..n {
230 p[i][h - 1] = bvec[i];
231 }
232 p[vrow][h - 1] = vr;
233 }
234 for i in 0..n {
236 bvec[i] = (1.0 - beta) * p[i][h - 1] + beta * p[i][cidx];
237 }
238 let mut fc = f(&bvec);
239 if !fc.is_finite() {
240 fc = BIG;
241 }
242 funcount += 1;
243 if fc < p[vrow][h - 1] {
244 for i in 0..n {
245 p[i][h - 1] = bvec[i];
246 }
247 p[vrow][h - 1] = fc;
248 } else if vr >= vh {
249 calcvert = true;
251 size = 0.0;
252 for j in 0..n1 {
253 if j + 1 != l {
254 for i in 0..n {
255 p[i][j] = beta * (p[i][j] - p[i][l - 1]) + p[i][l - 1];
256 size += (p[i][j] - p[i][l - 1]).abs();
257 }
258 }
259 }
260 if size < oldsize {
261 oldsize = size;
262 } else {
263 fail = 10;
264 break;
265 }
266 }
267 }
268
269 if funcount > maxit {
270 break;
271 }
272 }
273
274 let value = p[vrow][l - 1];
275 let par: Vec<f64> = (0..n).map(|i| p[i][l - 1]).collect();
276 if funcount > maxit {
277 fail = 1;
278 }
279 NelderMead {
280 par,
281 value,
282 fncount: funcount,
283 fail,
284 }
285}
286
#[cfg(test)]
// The reference values are long literals copied from an external run (the test
// names refer to R's `optim`), so clippy's precision/constant lints are silenced.
#[allow(clippy::excessive_precision, clippy::approx_constant)]
mod tests {
    use super::*;

    /// True when `actual` lies within a relative tolerance of 1e-7 of `expected`.
    fn close_rel(actual: f64, expected: f64) -> bool {
        let tolerance = 1e-7 * (1.0 + expected.abs());
        (actual - expected).abs() <= tolerance
    }

    /// Rosenbrock's banana function; minimum 0 at (1, 1).
    fn rosenbrock(x: &[f64]) -> f64 {
        let valley = x[1] - x[0] * x[0];
        let offset = 1.0 - x[0];
        100.0 * valley.powi(2) + offset.powi(2)
    }

    /// Quadratic with a cross term; minimum 0 at (1, 2).
    fn rotated_quadratic(x: &[f64]) -> f64 {
        let a = x[0] - 1.0;
        let b = x[1] - 2.0;
        a * a + 4.0 * b * b + a * b
    }

    /// Centre of the separable 3-D bowl.
    const TARGET: [f64; 3] = [1.0, 2.0, 3.0];

    /// Sum of squared distances to `TARGET`; minimum 0 at `TARGET`.
    fn separable_bowl(x: &[f64]) -> f64 {
        (0..3).map(|k| (x[k] - TARGET[k]).powi(2)).sum()
    }

    #[test]
    fn rosenbrock_matches_r_optim() {
        let out = nelder_mead(&[-1.2, 1.0], rosenbrock);
        assert!(close_rel(out.par[0], 1.0002601387256695), "par0 {}", out.par[0]);
        assert!(close_rel(out.par[1], 1.000505999303765), "par1 {}", out.par[1]);
        assert!(close_rel(out.value, 8.8252410967227472e-08), "val {}", out.value);
        assert_eq!(out.fncount, 195);
        assert_eq!(out.fail, 0);
    }

    #[test]
    fn rotated_quadratic_matches_r_optim() {
        let out = nelder_mead(&[0.0, 0.0], rotated_quadratic);
        assert!(close_rel(out.par[0], 1.000165999016468), "par0 {}", out.par[0]);
        assert!(close_rel(out.par[1], 2.000030536283584), "par1 {}", out.par[1]);
        assert!(close_rel(out.value, 3.6354524970336173e-08), "val {}", out.value);
        assert_eq!(out.fncount, 65);
        assert_eq!(out.fail, 0);
    }

    #[test]
    fn separable_3d_matches_r_optim() {
        let out = nelder_mead(&[0.0, 0.0, 0.0], separable_bowl);
        assert!(close_rel(out.par[0], 1.0005383213703034), "par0 {}", out.par[0]);
        assert!(close_rel(out.par[1], 1.9999848251163552), "par1 {}", out.par[1]);
        assert!(close_rel(out.par[2], 2.9999188239843706), "par2 {}", out.par[2]);
        assert!(close_rel(out.value, 2.9660972033243571e-07), "val {}", out.value);
        assert_eq!(out.fncount, 112);
        assert_eq!(out.fail, 0);
    }
}