use crate::error::FittingError;
/// Tuning parameters for [`nelder_mead_minimize`].
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMeadConfig {
    /// Convergence threshold on the simplex size. The run may stop once every vertex
    /// differs from the best vertex by at most this much in every coordinate
    /// (absolute units).
    pub xatol: f64,
    /// Convergence threshold on objective spread. The run may stop once the worst
    /// minus the best objective value in the simplex is at most this much. Both
    /// `xatol` and `fatol` must be satisfied at the same time.
    pub fatol: f64,
    /// Maximum number of main-loop iterations before giving up.
    pub max_iter: usize,
    /// Relative step used to build the initial simplex. A coordinate `c` with
    /// `|c| > 1e-8` is perturbed by `initial_step_frac * c`.
    pub initial_step_frac: f64,
    /// Absolute step used to build the initial simplex for coordinates with
    /// `|c| <= 1e-8`. It is also the fallback nudge when the relative step is
    /// absorbed by the bounds.
    pub initial_step_abs: f64,
}
impl Default for NelderMeadConfig {
fn default() -> Self {
Self {
xatol: 1e-4,
fatol: 1e-4,
max_iter: 5000,
initial_step_frac: 0.05,
initial_step_abs: 0.00025,
}
}
}
/// Outcome of a [`nelder_mead_minimize`] run.
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMeadResult {
    /// Best vertex of the final simplex.
    pub x: Vec<f64>,
    /// Objective value at `x`. Evaluations that returned `Err` or a non-finite
    /// value are recorded as `f64::INFINITY`.
    pub fun: f64,
    /// Number of main-loop passes, counting the pass on which convergence was
    /// detected (if any).
    pub iterations: usize,
    /// Total number of objective evaluations, including those used to build the
    /// initial simplex.
    pub n_evals: usize,
    /// `true` if the `xatol`/`fatol` criterion stopped the run, `false` if the
    /// iteration budget ran out first.
    pub self_converged: bool,
}
/// Minimizes `f` from the starting point `x0` with the Nelder–Mead downhill simplex method.
///
/// The simplex updates use reflection 1, expansion 2, contraction 0.5 and shrink 0.5.
/// The run stops when both of the following hold, or after `config.max_iter` iterations:
/// - every vertex lies within `config.xatol` of the best vertex in every coordinate;
/// - the worst minus the best objective value is at most `config.fatol`.
///
/// If `bounds` is given, every point (including `x0`) is mapped into the box before `f`
/// sees it. A coordinate below `lo` is mirrored to `2*lo - x`, and one above `hi` to
/// `2*hi - x`. If the mirror image still lies outside `[lo, hi]`, it is clamped.
/// NaN coordinates are left untouched.
///
/// If `f` returns `Err` or a non-finite value, that point is scored as `f64::INFINITY`.
/// The simplex then moves away from infeasible regions instead of aborting.
///
/// # Errors
/// Objective failures are absorbed as described above, so this currently always
/// returns `Ok`.
///
/// # Panics
/// Panics if:
/// - `x0` is empty;
/// - `bounds` has a different length from `x0`;
/// - any bound fails `lo <= hi`, including bounds containing NaN.
pub fn nelder_mead_minimize<F>(
    mut f: F,
    x0: &[f64],
    bounds: Option<&[(f64, f64)]>,
    config: &NelderMeadConfig,
) -> Result<NelderMeadResult, FittingError>
where
    F: FnMut(&[f64]) -> Result<f64, FittingError>,
{
    const ALPHA: f64 = 1.0;
    const GAMMA: f64 = 2.0;
    const RHO: f64 = 0.5;
    const SIGMA: f64 = 0.5;

    let n = x0.len();
    assert!(n > 0, "nelder_mead_minimize: x0 must not be empty");
    if let Some(b) = bounds {
        assert_eq!(
            b.len(),
            n,
            "nelder_mead_minimize: bounds length {} != x0 length {}",
            b.len(),
            n
        );
        for (i, &(lo, hi)) in b.iter().enumerate() {
            assert!(
                lo <= hi,
                "nelder_mead_minimize: bound {i} has lo {lo} > hi {hi}"
            );
        }
    }

    let mut n_evals = 0usize;
    // Failed or non-finite evaluations become +inf, so they always rank as the worst vertex.
    let mut eval = |x: &[f64]| -> f64 {
        n_evals += 1;
        match f(x) {
            Ok(v) if v.is_finite() => v,
            _ => f64::INFINITY,
        }
    };

    let (mut simplex, mut fvals) = initial_simplex(x0, bounds, config, &mut eval);
    sort_simplex(&mut simplex, &mut fvals);

    // Scratch buffers reused across iterations.
    let mut centroid = vec![0.0; n];
    let mut xr = vec![0.0; n];
    let mut xe = vec![0.0; n];
    let mut xc = vec![0.0; n];
    let mut iter = 0usize;
    let mut self_converged = false;

    while iter < config.max_iter {
        iter += 1;

        // `fvals` is sorted ascending: index 0 is the best vertex, index n the worst.
        let frange = fvals[n] - fvals[0];
        if simplex_spread(&simplex) <= config.xatol && frange <= config.fatol {
            self_converged = true;
            break;
        }

        // Centroid of every vertex except the worst.
        for (j, c) in centroid.iter_mut().enumerate() {
            *c = simplex[..n].iter().fold(0.0, |acc, v| acc + v[j]) / n as f64;
        }

        // Reflect the worst vertex through the centroid.
        for ((r, &c), &w) in xr.iter_mut().zip(&centroid).zip(&simplex[n]) {
            *r = c + ALPHA * (c - w);
        }
        reflect_into_bounds(&mut xr, bounds);
        let fxr = eval(&xr);

        if fxr < fvals[0] {
            // The reflection beat the best vertex, so try going further in that direction.
            for ((e, &c), &r) in xe.iter_mut().zip(&centroid).zip(&xr) {
                *e = c + GAMMA * (r - c);
            }
            reflect_into_bounds(&mut xe, bounds);
            let fxe = eval(&xe);
            if fxe < fxr {
                simplex[n].clone_from(&xe);
                fvals[n] = fxe;
            } else {
                simplex[n].clone_from(&xr);
                fvals[n] = fxr;
            }
        } else if fxr < fvals[n - 1] {
            // The reflection beats the second-worst vertex: accept it as is.
            simplex[n].clone_from(&xr);
            fvals[n] = fxr;
        } else {
            // Contract toward the reflected point if it improved on the worst vertex
            // (outside contraction); otherwise contract toward the worst vertex
            // itself (inside contraction).
            let (x_src, f_src) = if fxr < fvals[n] {
                (&xr, fxr)
            } else {
                (&simplex[n], fvals[n])
            };
            for ((k, &c), &s) in xc.iter_mut().zip(&centroid).zip(x_src) {
                *k = c + RHO * (s - c);
            }
            reflect_into_bounds(&mut xc, bounds);
            let fxc = eval(&xc);
            if fxc < f_src {
                simplex[n].clone_from(&xc);
                fvals[n] = fxc;
            } else {
                shrink_toward_best(&mut simplex, &mut fvals, SIGMA, bounds, &mut eval);
            }
        }

        sort_simplex(&mut simplex, &mut fvals);
    }

    Ok(NelderMeadResult {
        x: simplex[0].clone(),
        fun: fvals[0],
        iterations: iter,
        n_evals,
        self_converged,
    })
}

/// Maps `x` into the box `bounds`, coordinate by coordinate.
///
/// A value below `lo` (or above `hi`) is mirrored across that bound. If the mirror
/// image overshoots the opposite side, it is clamped to `[lo, hi]`. This is a no-op
/// when `bounds` is `None`. The caller must ensure `lo <= hi` for every bound;
/// `f64::clamp` panics otherwise.
fn reflect_into_bounds(x: &mut [f64], bounds: Option<&[(f64, f64)]>) {
    if let Some(b) = bounds {
        for (xi, &(lo, hi)) in x.iter_mut().zip(b) {
            if *xi < lo {
                *xi = (2.0 * lo - *xi).clamp(lo, hi);
            } else if *xi > hi {
                *xi = (2.0 * hi - *xi).clamp(lo, hi);
            }
        }
    }
}

/// Builds the starting simplex and evaluates each vertex in order.
///
/// The first vertex is the projected `x0`. Each further vertex perturbs one
/// coordinate of it along its own axis.
fn initial_simplex<E>(
    x0: &[f64],
    bounds: Option<&[(f64, f64)]>,
    config: &NelderMeadConfig,
    eval: &mut E,
) -> (Vec<Vec<f64>>, Vec<f64>)
where
    E: FnMut(&[f64]) -> f64,
{
    let n = x0.len();
    let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    let mut fvals: Vec<f64> = Vec::with_capacity(n + 1);

    let mut v0 = x0.to_vec();
    reflect_into_bounds(&mut v0, bounds);
    fvals.push(eval(&v0));
    simplex.push(v0.clone());

    for i in 0..n {
        let mut v = v0.clone();
        let base = v[i];
        // Relative step for non-negligible coordinates, absolute step near zero.
        let step = if base.abs() > 1e-8 {
            config.initial_step_frac * base
        } else {
            config.initial_step_abs
        };
        // Projection can map the perturbed coordinate right back onto `base`, for
        // example with a zero-width bound. Try +step, then -step, then a fixed
        // absolute nudge.
        v[i] = base + step;
        reflect_into_bounds(&mut v, bounds);
        if (v[i] - base).abs() < 1e-14 {
            v[i] = base - step;
            reflect_into_bounds(&mut v, bounds);
            if (v[i] - base).abs() < 1e-14 {
                let sign = if base >= 0.0 { 1.0 } else { -1.0 };
                v[i] = base + config.initial_step_abs.copysign(sign);
                reflect_into_bounds(&mut v, bounds);
            }
        }
        fvals.push(eval(&v));
        simplex.push(v);
    }

    (simplex, fvals)
}

/// Returns the largest absolute per-coordinate distance from any vertex to the best
/// vertex `simplex[0]`.
fn simplex_spread(simplex: &[Vec<f64>]) -> f64 {
    let best = &simplex[0];
    simplex
        .iter()
        .flat_map(move |v| v.iter().zip(best).map(|(&a, &b)| (a - b).abs()))
        .fold(0.0_f64, |acc, d| if d > acc { d } else { acc })
}

/// Pulls every vertex except the best toward `simplex[0]` by factor `sigma`.
///
/// Each moved vertex is projected and re-evaluated in index order.
fn shrink_toward_best<E>(
    simplex: &mut [Vec<f64>],
    fvals: &mut [f64],
    sigma: f64,
    bounds: Option<&[(f64, f64)]>,
    eval: &mut E,
) where
    E: FnMut(&[f64]) -> f64,
{
    if let Some((best, rest)) = simplex.split_first_mut() {
        for (v, fv) in rest.iter_mut().zip(fvals.iter_mut().skip(1)) {
            for (xj, &bj) in v.iter_mut().zip(best.iter()) {
                *xj = bj + sigma * (*xj - bj);
            }
            reflect_into_bounds(v, bounds);
            *fv = eval(v.as_slice());
        }
    }
}

/// Reorders the vertices by ascending objective value.
///
/// The sort is stable, so ties keep their current order. Vertices are moved rather
/// than cloned.
fn sort_simplex(simplex: &mut Vec<Vec<f64>>, fvals: &mut Vec<f64>) {
    let mut paired: Vec<(f64, Vec<f64>)> = fvals.drain(..).zip(simplex.drain(..)).collect();
    paired.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
    for (fv, v) in paired {
        fvals.push(fv);
        simplex.push(v);
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Smooth 1-D bowl: should converge tightly and report self-convergence.
    #[test]
    fn test_nm_quadratic_1d_converges() {
        let f = |x: &[f64]| Ok((x[0] - 3.0).powi(2));
        let cfg = NelderMeadConfig {
            xatol: 1e-10,
            fatol: 1e-12,
            max_iter: 5000,
            initial_step_frac: 0.1,
            initial_step_abs: 0.01,
        };
        let r = nelder_mead_minimize(f, &[0.0], None, &cfg).unwrap();
        assert!((r.x[0] - 3.0).abs() < 1e-6, "x = {:?}", r.x);
        assert!(r.fun < 1e-12);
        assert!(r.self_converged);
    }

    // Classic curved-valley benchmark from the standard (-1.2, 1.0) start.
    #[test]
    fn test_nm_rosenbrock_2d() {
        let f = |x: &[f64]| Ok((1.0 - x[0]).powi(2) + 100.0 * (x[1] - x[0].powi(2)).powi(2));
        let cfg = NelderMeadConfig {
            xatol: 1e-6,
            fatol: 1e-8,
            max_iter: 10_000,
            initial_step_frac: 0.1,
            initial_step_abs: 0.01,
        };
        let r = nelder_mead_minimize(f, &[-1.2, 1.0], None, &cfg).unwrap();
        assert!(
            (r.x[0] - 1.0).abs() < 1e-3 && (r.x[1] - 1.0).abs() < 1e-3,
            "Rosenbrock minimizer off: x = {:?} fun = {}",
            r.x,
            r.fun
        );
        assert!(r.fun < 1e-6);
    }

    // The unconstrained minimum (x = 5) lies outside [0, 2]. The objective itself
    // asserts that it is never called with an out-of-bounds point.
    #[test]
    fn test_nm_respects_bounds_reflection() {
        let lo = 0.0;
        let hi = 2.0;
        let f = {
            move |x: &[f64]| -> Result<f64, FittingError> {
                assert!(
                    x[0] >= lo - 1e-12 && x[0] <= hi + 1e-12,
                    "NM passed out-of-bounds x = {}",
                    x[0]
                );
                Ok((x[0] - 5.0).powi(2))
            }
        };
        let cfg = NelderMeadConfig::default();
        let bounds = [(lo, hi)];
        let r = nelder_mead_minimize(f, &[1.0], Some(&bounds), &cfg).unwrap();
        assert!(
            (r.x[0] - 2.0).abs() < 1e-2,
            "expected x ≈ 2, got {}",
            r.x[0]
        );
        assert!(r.x[0] >= lo - 1e-12 && r.x[0] <= hi + 1e-12);
    }

    // Objective errors are scored as +inf, so the search should steer around them.
    #[test]
    fn test_nm_handles_infeasible_objective() {
        let f = |x: &[f64]| -> Result<f64, FittingError> {
            if x[0] < 0.1 {
                Err(FittingError::EvaluationFailed("x too small".into()))
            } else {
                Ok((x[0] - 0.5).powi(2))
            }
        };
        let cfg = NelderMeadConfig {
            xatol: 1e-8,
            fatol: 1e-10,
            max_iter: 5000,
            initial_step_frac: 0.2,
            initial_step_abs: 0.05,
        };
        let r = nelder_mead_minimize(f, &[1.0], None, &cfg).unwrap();
        assert!(
            (r.x[0] - 0.5).abs() < 1e-3,
            "expected x ≈ 0.5, got {} (fun = {})",
            r.x[0],
            r.fun
        );
    }

    // With no iteration budget, only the initial simplex is built and the best of its
    // two vertices (x0 = 1.0 and x0 + 10 % = 1.1) is returned unconverged.
    #[test]
    fn test_nm_zero_iterations_returns_best_initial_vertex() {
        let f = |x: &[f64]| Ok((x[0] - 3.0).powi(2));
        let cfg = NelderMeadConfig {
            max_iter: 0,
            initial_step_frac: 0.1,
            ..NelderMeadConfig::default()
        };
        let r = nelder_mead_minimize(f, &[1.0], None, &cfg).unwrap();
        assert_eq!(r.iterations, 0);
        assert_eq!(r.n_evals, 2);
        assert!(!r.self_converged);
        assert!((r.x[0] - 1.1).abs() < 1e-12, "x = {:?}", r.x);
    }

    // The reported evaluation count must match how often the objective was actually called.
    #[test]
    fn test_nm_n_evals_matches_objective_calls() {
        let mut calls = 0usize;
        let f = |x: &[f64]| -> Result<f64, FittingError> {
            calls += 1;
            Ok((x[0] - 1.0).powi(2) + (x[1] + 2.0).powi(2))
        };
        let r = nelder_mead_minimize(f, &[0.0, 0.0], None, &NelderMeadConfig::default())
            .unwrap();
        assert_eq!(r.n_evals, calls);
        assert!(r.n_evals > r.iterations);
    }
}