/// Outcome of a [`nelder_mead`] run.
#[derive(Clone, Debug)]
pub struct Minimum {
    /// Vertex of the final simplex with the lowest objective value.
    pub x: Vec<f64>,
    /// Objective value at [`Minimum::x`].
    pub f: f64,
    /// Number of main-loop iterations that were started (never exceeds `max_iter`).
    pub iterations: usize,
    /// `true` when both the `ftol` and `xtol` criteria were met before the
    /// iteration limit ran out.
    pub converged: bool,
}

/// Tuning parameters for [`nelder_mead`].
///
/// Override individual fields with struct-update syntax, e.g.
/// `NelderMeadOptions { max_iter: 2000, ..Default::default() }`.
#[derive(Copy, Clone, Debug)]
pub struct NelderMeadOptions {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// Absolute tolerance on `|f(worst) - f(best)|` across the simplex.
    pub ftol: f64,
    /// Absolute tolerance on the largest per-coordinate extent of the simplex.
    pub xtol: f64,
    /// Size of the initial simplex: coordinate `i` is offset by
    /// `step_frac * |x0[i]|`, or by `max(step_frac, 1e-4)` when `|x0[i]| <= 1e-12`.
    pub step_frac: f64,
}

impl Default for NelderMeadOptions {
    /// `max_iter = 500`, `ftol = 1e-8`, `xtol = 1e-8`, `step_frac = 0.05`.
    fn default() -> Self {
        Self {
            max_iter: 500,
            ftol: 1.0e-8,
            xtol: 1.0e-8,
            step_frac: 0.05,
        }
    }
}

/// Total ordering on objective values in which `NaN` ranks after every number.
///
/// `partial_cmp(..).unwrap()` would panic the first time the objective returned
/// `NaN`; ranking `NaN` as the worst possible value instead lets the simplex
/// move away from such points. Two `NaN`s compare equal, so this is a valid
/// total preorder for `sort_by` / `min_by`.
fn nan_last_cmp(a: f64, b: f64) -> std::cmp::Ordering {
    a.partial_cmp(&b)
        .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
}

/// `true` when `a` is strictly better (smaller) than `b` under [`nan_last_cmp`].
fn is_better(a: f64, b: f64) -> bool {
    nan_last_cmp(a, b).is_lt()
}

/// Returns `origin + t * (toward - origin)`, computed component-wise.
fn affine_point(origin: &[f64], toward: &[f64], t: f64) -> Vec<f64> {
    origin
        .iter()
        .zip(toward)
        .map(|(&o, &p)| o + t * (p - o))
        .collect()
}

/// Minimizes `f` with the Nelder–Mead downhill-simplex method, starting at `x0`.
///
/// The initial simplex is `x0` plus one vertex per coordinate axis (see
/// [`NelderMeadOptions::step_frac`]). Each iteration reflects the worst vertex
/// through the centroid of the others and then, depending on how good the
/// reflected point is, accepts it, expands further, or performs an outside or
/// inside contraction. If no trial point is accepted, every vertex is shrunk
/// halfway toward the best one.
///
/// `NaN` objective values are ranked as worse than any number, so they do not
/// cause a panic and such vertices are the first to be replaced.
///
/// Iteration stops with `converged = true` once `|f(worst) - f(best)| < ftol`
/// **and** the largest per-coordinate extent of the simplex is `< xtol`, or
/// with `converged = false` after `max_iter` iterations.
///
/// # Panics
///
/// Panics if `x0` is empty.
pub fn nelder_mead<F>(mut f: F, x0: &[f64], opts: NelderMeadOptions) -> Minimum
where
    F: FnMut(&[f64]) -> f64,
{
    let n = x0.len();
    assert!(n >= 1, "nelder_mead: x0 must have at least one coordinate");

    let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    simplex.push(x0.to_vec());
    for (i, &xi) in x0.iter().enumerate() {
        let mut vertex = x0.to_vec();
        // Relative step for non-zero coordinates; an absolute step (at least 1e-4)
        // around zero, where a relative step would vanish.
        vertex[i] += if xi.abs() > 1e-12 {
            opts.step_frac * xi.abs()
        } else {
            opts.step_frac.max(1e-4)
        };
        simplex.push(vertex);
    }
    let mut values: Vec<f64> = simplex.iter().map(|v| f(v)).collect();

    let mut iteration = 0_usize;
    let mut converged = false;
    while iteration < opts.max_iter {
        iteration += 1;

        // Rank vertices best -> worst; the sort is stable, so ties keep index order.
        let mut order: Vec<usize> = (0..=n).collect();
        order.sort_by(|&a, &b| nan_last_cmp(values[a], values[b]));
        let best = order[0];
        let worst = order[n];
        let second_worst = order[n - 1];

        let f_spread = values[worst] - values[best];
        let x_spread = (0..n)
            .map(|j| {
                let seed = simplex[best][j];
                let (lo, hi) = simplex
                    .iter()
                    .fold((seed, seed), |(lo, hi), v| (lo.min(v[j]), hi.max(v[j])));
                hi - lo
            })
            .fold(0.0_f64, f64::max);
        // A NaN spread compares false, so a simplex containing NaN never "converges".
        if f_spread.abs() < opts.ftol && x_spread < opts.xtol {
            converged = true;
            break;
        }

        // Centroid of every vertex except the worst.
        let mut centroid = vec![0.0_f64; n];
        for &idx in &order[..n] {
            for (c, &x) in centroid.iter_mut().zip(&simplex[idx]) {
                *c += x;
            }
        }
        for c in &mut centroid {
            *c /= n as f64;
        }

        // Reflection: centroid + (centroid - worst).
        let x_r = affine_point(&centroid, &simplex[worst], -1.0);
        let f_r = f(&x_r);

        let accepted: Option<(Vec<f64>, f64)> = if is_better(f_r, values[best]) {
            // Expansion: push twice as far past the centroid; keep the better of the two.
            let x_e = affine_point(&centroid, &x_r, 2.0);
            let f_e = f(&x_e);
            Some(if is_better(f_e, f_r) {
                (x_e, f_e)
            } else {
                (x_r, f_r)
            })
        } else if is_better(f_r, values[second_worst]) {
            // best <= f_r < second_worst: plain reflection.
            Some((x_r, f_r))
        } else if is_better(f_r, values[worst]) {
            // Outside contraction: x_r beat the worst vertex, so contract toward it
            // rather than discarding it.
            let x_oc = affine_point(&centroid, &x_r, 0.5);
            let f_oc = f(&x_oc);
            (!is_better(f_r, f_oc)).then_some((x_oc, f_oc))
        } else {
            // Inside contraction: halfway between the centroid and the worst vertex.
            let x_ic = affine_point(&centroid, &simplex[worst], 0.5);
            let f_ic = f(&x_ic);
            is_better(f_ic, values[worst]).then_some((x_ic, f_ic))
        };

        match accepted {
            Some((x_new, f_new)) => {
                simplex[worst] = x_new;
                values[worst] = f_new;
            }
            None => {
                // Shrink every vertex halfway toward the best one.
                for &idx in &order[1..] {
                    let shrunk = affine_point(&simplex[best], &simplex[idx], 0.5);
                    values[idx] = f(&shrunk);
                    simplex[idx] = shrunk;
                }
            }
        }
    }

    let best_idx = (0..=n)
        .min_by(|&a, &b| nan_last_cmp(values[a], values[b]))
        .expect("simplex always has n + 1 >= 2 vertices");
    let f_best = values[best_idx];
    Minimum {
        x: simplex.swap_remove(best_idx),
        f: f_best,
        iterations: iteration,
        converged,
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn quadratic_bowl_converges() {
        let target = [1.0_f64, -2.0, 3.0, 0.5];
        // `[f64; 4]` is `Copy`, so the `move` closure takes its own copy and
        // `target` stays usable for the assertions below.
        let f = move |x: &[f64]| -> f64 {
            x.iter()
                .zip(target.iter())
                .map(|(a, b)| (a - b).powi(2))
                .sum()
        };
        let x0 = [0.0_f64; 4];
        let m = nelder_mead(f, &x0, NelderMeadOptions::default());
        assert!(m.converged, "should converge within {} iters", m.iterations);
        for (got, want) in m.x.iter().zip(target.iter()) {
            assert!((got - want).abs() < 1.0e-5, "got {} vs {}", got, want);
        }
        assert!(m.f < 1.0e-8);
    }

    #[test]
    fn rosenbrock_converges() {
        let f = |x: &[f64]| -> f64 {
            let (a, b) = (x[0], x[1]);
            (1.0 - a).powi(2) + 100.0 * (b - a * a).powi(2)
        };
        let m = nelder_mead(
            f,
            &[-1.2_f64, 1.0],
            NelderMeadOptions {
                max_iter: 2000,
                ..Default::default()
            },
        );
        assert!(
            (m.x[0] - 1.0).abs() < 1e-3 && (m.x[1] - 1.0).abs() < 1e-3,
            "got {:?}",
            m.x
        );
    }

    #[test]
    fn max_iter_cap_respected() {
        let x0 = [0.1_f64, 0.2, 0.3];
        let max_iter = 5;
        let mut calls = 0_usize;
        let f = |x: &[f64]| -> f64 {
            calls += 1;
            x.iter().map(|v| v.sin()).sum()
        };
        let m = nelder_mead(
            f,
            &x0,
            NelderMeadOptions {
                max_iter,
                ..Default::default()
            },
        );
        assert!(m.iterations <= max_iter);
        // n + 1 evaluations build the initial simplex; each iteration then costs at
        // most n + 2 (reflection, one expansion or contraction, n shrink points).
        let n = x0.len();
        assert!(
            calls <= (n + 1) + max_iter * (n + 2),
            "made {} objective calls",
            calls
        );
    }
}