Skip to main content

limma/
optim.rs

1//! Nelder–Mead simplex minimizer, a faithful port of R's `nmmin`
2//! (`src/appl/optim.c`), which backs `optim(method = "Nelder-Mead")`.
3//!
4//! The algorithm is deterministic, so with the same starting point, objective
5//! and control constants it reproduces R's iterate path — and therefore R's
6//! `optim` result — bit for bit. This is what limma's `genas` relies on.
7//!
8//! Only the plain (unconstrained, unscaled) Nelder–Mead is ported; parameter
9//! scaling (`parscale`) and the other `optim` methods are out of scope.
10
/// Outcome of a Nelder–Mead run ([`nelder_mead`] / [`nelder_mead_with`]),
/// carrying the pieces of R's `optim` return value that callers use.
///
/// `PartialEq` is derived so whole results can be compared directly (for
/// example in regression tests); equality is exact, field by field.
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMead {
    /// Best vertex of the final simplex (R's `$par`).
    pub par: Vec<f64>,
    /// Objective value stored for `par` (R's `$value`).
    pub value: f64,
    /// How many times the objective was evaluated (R's `$counts[1]`).
    pub fncount: usize,
    /// Status code (R's `$convergence`):
    /// * `0`  – a convergence tolerance was met,
    /// * `1`  – the evaluation count exceeded `maxit`,
    /// * `10` – a shrink step failed to reduce the simplex size,
    /// * `2`  – the objective was non-finite at the starting point.
    pub fail: i32,
}
26
/// Large finite stand-in used in place of a NaN or infinite objective value,
/// so that such points simply rank as very bad vertices. Same magnitude as
/// the `big` constant in R's `nmmin`.
const BIG: f64 = 1e35;
29
30/// Minimize `f` starting from `x0` using Nelder–Mead with R's default control
31/// constants: `reltol = sqrt(f64::EPSILON)`, `abstol = -inf`, `maxit = 500`,
32/// reflection `alpha = 1`, contraction `beta = 0.5`, expansion `gamma = 2`.
33pub fn nelder_mead<F>(x0: &[f64], f: F) -> NelderMead
34where
35    F: FnMut(&[f64]) -> f64,
36{
37    nelder_mead_with(
38        x0,
39        f,
40        f64::NEG_INFINITY,
41        f64::EPSILON.sqrt(),
42        500,
43        1.0,
44        0.5,
45        2.0,
46    )
47}
48
49/// Nelder–Mead with explicit control constants (see [`nelder_mead`] for the
50/// defaults). `abstol`/`reltol` are R's `optim` `abstol`/`reltol`.
51// The `for i in 0..n` / `for j in 0..n1` index loops below deliberately mirror
52// the index arithmetic of R's `nmmin` C source (`p[i][col]`, `bvec[i]` updated
53// in lockstep), so the iterate path stays bit-for-bit identical to R's.
54#[allow(clippy::too_many_arguments, clippy::needless_range_loop)]
55pub fn nelder_mead_with<F>(
56    x0: &[f64],
57    mut f: F,
58    abstol: f64,
59    reltol: f64,
60    maxit: usize,
61    alpha: f64,
62    beta: f64,
63    gamma: f64,
64) -> NelderMead
65where
66    F: FnMut(&[f64]) -> f64,
67{
68    let n = x0.len();
69    let mut bvec = x0.to_vec();
70
71    if maxit == 0 {
72        let value = f(&bvec);
73        return NelderMead {
74            par: bvec,
75            value,
76            fncount: 0,
77            fail: 0,
78        };
79    }
80    if n == 0 {
81        let value = f(&bvec);
82        return NelderMead {
83            par: bvec,
84            value,
85            fncount: 1,
86            fail: 0,
87        };
88    }
89
90    // P has n+1 rows (n coordinate rows + one value row at index `n`) and n+2
91    // columns (n+1 simplex vertices in columns 0..=n, plus a scratch/centroid
92    // column at index n+1). Mirrors R's `matrix(n, n+1)`.
93    let n1 = n + 1; // number of vertices
94    let cidx = n + 1; // scratch column index (R's C-1)
95    let vrow = n; // value row index (R's n1-1)
96    let mut p = vec![vec![0.0_f64; n + 2]; n + 1];
97
98    let f0 = f(&bvec);
99    if !f0.is_finite() {
100        return NelderMead {
101            par: bvec,
102            value: f0,
103            fncount: 1,
104            fail: 2,
105        };
106    }
107    let mut funcount = 1usize;
108    let convtol = reltol * (f0.abs() + reltol);
109
110    p[vrow][0] = f0;
111    for i in 0..n {
112        p[i][0] = bvec[i];
113    }
114
115    let mut l = 1usize; // 1-indexed best vertex
116    let mut size = 0.0;
117
118    let mut step = 0.0;
119    for i in 0..n {
120        let s = 0.1 * bvec[i].abs();
121        if s > step {
122            step = s;
123        }
124    }
125    if step == 0.0 {
126        step = 0.1;
127    }
128
129    // Build the remaining n vertices by perturbing one coordinate each.
130    for j in 2..=n1 {
131        for i in 0..n {
132            p[i][j - 1] = bvec[i];
133        }
134        let mut trystep = step;
135        while p[j - 2][j - 1] == bvec[j - 2] {
136            p[j - 2][j - 1] = bvec[j - 2] + trystep;
137            trystep *= 10.0;
138        }
139        size += trystep;
140    }
141    let mut oldsize = size;
142    let mut calcvert = true;
143    let mut fail = 0i32;
144
145    loop {
146        if calcvert {
147            for j in 0..n1 {
148                if j + 1 != l {
149                    for i in 0..n {
150                        bvec[i] = p[i][j];
151                    }
152                    let mut fj = f(&bvec);
153                    if !fj.is_finite() {
154                        fj = BIG;
155                    }
156                    funcount += 1;
157                    p[vrow][j] = fj;
158                }
159            }
160            calcvert = false;
161        }
162
163        let mut vl = p[vrow][l - 1];
164        let mut vh = vl;
165        let mut h = l;
166        for j in 1..=n1 {
167            if j != l {
168                let fj = p[vrow][j - 1];
169                if fj < vl {
170                    l = j;
171                    vl = fj;
172                }
173                if fj > vh {
174                    h = j;
175                    vh = fj;
176                }
177            }
178        }
179
180        if vh <= vl + convtol || vl <= abstol {
181            break;
182        }
183
184        // Centroid of all vertices except the worst (H), into the scratch col.
185        for i in 0..n {
186            let mut temp = -p[i][h - 1];
187            for j in 0..n1 {
188                temp += p[i][j];
189            }
190            p[i][cidx] = temp / n as f64;
191        }
192        // Reflect the worst vertex through the centroid.
193        for i in 0..n {
194            bvec[i] = (1.0 + alpha) * p[i][cidx] - alpha * p[i][h - 1];
195        }
196        let mut vr = f(&bvec);
197        if !vr.is_finite() {
198            vr = BIG;
199        }
200        funcount += 1;
201
202        if vr < vl {
203            // Reflection improved on the best: try to expand further.
204            p[vrow][cidx] = vr;
205            for i in 0..n {
206                let fe = gamma * bvec[i] + (1.0 - gamma) * p[i][cidx];
207                p[i][cidx] = bvec[i];
208                bvec[i] = fe;
209            }
210            let mut fe = f(&bvec);
211            if !fe.is_finite() {
212                fe = BIG;
213            }
214            funcount += 1;
215            if fe < vr {
216                for i in 0..n {
217                    p[i][h - 1] = bvec[i];
218                }
219                p[vrow][h - 1] = fe;
220            } else {
221                for i in 0..n {
222                    p[i][h - 1] = p[i][cidx];
223                }
224                p[vrow][h - 1] = vr;
225            }
226        } else {
227            // Reflection no better than the best.
228            if vr < vh {
229                for i in 0..n {
230                    p[i][h - 1] = bvec[i];
231                }
232                p[vrow][h - 1] = vr;
233            }
234            // Contract.
235            for i in 0..n {
236                bvec[i] = (1.0 - beta) * p[i][h - 1] + beta * p[i][cidx];
237            }
238            let mut fc = f(&bvec);
239            if !fc.is_finite() {
240                fc = BIG;
241            }
242            funcount += 1;
243            if fc < p[vrow][h - 1] {
244                for i in 0..n {
245                    p[i][h - 1] = bvec[i];
246                }
247                p[vrow][h - 1] = fc;
248            } else if vr >= vh {
249                // Shrink the whole simplex toward the best vertex.
250                calcvert = true;
251                size = 0.0;
252                for j in 0..n1 {
253                    if j + 1 != l {
254                        for i in 0..n {
255                            p[i][j] = beta * (p[i][j] - p[i][l - 1]) + p[i][l - 1];
256                            size += (p[i][j] - p[i][l - 1]).abs();
257                        }
258                    }
259                }
260                if size < oldsize {
261                    oldsize = size;
262                } else {
263                    fail = 10;
264                    break;
265                }
266            }
267        }
268
269        if funcount > maxit {
270            break;
271        }
272    }
273
274    let value = p[vrow][l - 1];
275    let par: Vec<f64> = (0..n).map(|i| p[i][l - 1]).collect();
276    if funcount > maxit {
277        fail = 1;
278    }
279    NelderMead {
280        par,
281        value,
282        fncount: funcount,
283        fail,
284    }
285}
286
#[cfg(test)]
// The reference literals below are written out with many digits.
#[allow(clippy::excessive_precision, clippy::approx_constant)]
mod tests {
    use super::*;

    /// Mixed absolute/relative closeness: `|a - b| <= 1e-7 * (1 + |b|)`.
    fn rclose(a: f64, b: f64) -> bool {
        let tol = 1e-7 * (1.0 + b.abs());
        (a - b).abs() <= tol
    }

    // Rosenbrock: f(x,y) = 100*(y - x^2)^2 + (1 - x)^2, start (-1.2, 1).
    #[test]
    fn rosenbrock_matches_r_optim() {
        let rosenbrock =
            |x: &[f64]| 100.0 * (x[1] - x[0] * x[0]).powi(2) + (1.0 - x[0]).powi(2);
        let res = nelder_mead(&[-1.2, 1.0], rosenbrock);

        // Reference: optim(c(-1.2,1), fn, method="Nelder-Mead") gives
        // par = (1.0002601387256695, 1.000505999303765),
        // value = 8.8252410967e-08, counts[1] = 195, convergence = 0.
        let want_par = [1.0002601387256695, 1.000505999303765];
        for (k, &want) in want_par.iter().enumerate() {
            assert!(rclose(res.par[k], want), "par{} {}", k, res.par[k]);
        }
        assert!(
            rclose(res.value, 8.8252410967227472e-08),
            "val {}",
            res.value
        );
        assert_eq!(res.fncount, 195);
        assert_eq!(res.fail, 0);
    }

    // Rotated quadratic: f = (x1-1)^2 + 4*(x2-2)^2 + (x1-1)*(x2-2), start (0,0).
    #[test]
    fn rotated_quadratic_matches_r_optim() {
        let rotated = |x: &[f64]| {
            let a = x[0] - 1.0;
            let b = x[1] - 2.0;
            a * a + 4.0 * b * b + a * b
        };
        let res = nelder_mead(&[0.0, 0.0], rotated);

        // Reference: optim(c(0,0), fn, method="Nelder-Mead") gives
        // par = (1.000165999016468, 2.000030536283584),
        // value = 3.6354524970e-08, counts[1] = 65, convergence = 0.
        let want_par = [1.000165999016468, 2.000030536283584];
        for (k, &want) in want_par.iter().enumerate() {
            assert!(rclose(res.par[k], want), "par{} {}", k, res.par[k]);
        }
        assert!(
            rclose(res.value, 3.6354524970336173e-08),
            "val {}",
            res.value
        );
        assert_eq!(res.fncount, 65);
        assert_eq!(res.fail, 0);
    }

    // Separable 3-D quadratic: f = sum (x - c)^2, c = (1,2,3), start (0,0,0).
    #[test]
    fn separable_3d_matches_r_optim() {
        let c = [1.0, 2.0, 3.0];
        let separable = |x: &[f64]| {
            x.iter()
                .zip(c.iter())
                .map(|(xi, ci)| (xi - ci).powi(2))
                .sum::<f64>()
        };
        let res = nelder_mead(&[0.0, 0.0, 0.0], separable);

        // Reference: optim(c(0,0,0), fn, method="Nelder-Mead") gives
        // par = (1.0005383213703034, 1.9999848251163552, 2.9999188239843706),
        // value = 2.9660972033e-07, counts[1] = 112, convergence = 0.
        let want_par = [
            1.0005383213703034,
            1.9999848251163552,
            2.9999188239843706,
        ];
        for (k, &want) in want_par.iter().enumerate() {
            assert!(rclose(res.par[k], want), "par{} {}", k, res.par[k]);
        }
        assert!(
            rclose(res.value, 2.9660972033243571e-07),
            "val {}",
            res.value
        );
        assert_eq!(res.fncount, 112);
        assert_eq!(res.fail, 0);
    }
}