use crate::error::Result;
use num_traits::Float;
use super::{OptimizeConfig, OptimizeResult};
/// Minimizes `f` with the Nelder-Mead (downhill simplex) method, starting from `x0`.
///
/// The initial simplex is `x0` plus `n = x0.len()` extra vertices, each equal to `x0`
/// with one coordinate increased by `0.05`. Each iteration ranks the vertices by
/// objective value and tries to replace the worst one:
///
/// * **Reflection** of the worst vertex through the centroid of the others.
/// * **Expansion** further along that direction, tried when the reflection beats the
///   current best vertex.
/// * **Contraction** when the reflection is no better than the second-worst vertex.
///   It is outside (toward the reflected point) when the reflection still beats the
///   worst vertex, and inside (toward the worst vertex) otherwise.
/// * **Shrink** of every vertex toward the best one, used when the contraction fails
///   to improve.
///
/// Coefficients are reflection 1, expansion 2, contraction 0.5 and shrink 0.5.
///
/// NaN objective values are ranked worse than every non-NaN value. A vertex that
/// leaves the function's domain is therefore the first candidate for replacement and
/// is never reported as the best point (unless every vertex is NaN).
///
/// # Arguments
/// * `f` - objective function to minimize; called with slices of length `x0.len()`.
/// * `x0` - starting point; its length is the problem dimension.
/// * `config` - options; `None` means `OptimizeConfig::default()`. Reads `max_iter`
///   (iteration budget) and `ftol`. The run stops successfully once the worst vertex
///   value minus the best vertex value is below `ftol`.
///
/// # Returns
/// An `OptimizeResult` holding the best vertex found and its value. `grad` is all
/// zeros and `njev` is 0 because no derivatives are used. `nit` is the number of
/// completed iterations. `success` is `true` when the `ftol` criterion is met (or
/// `x0` is empty) and `false` when `max_iter` is exhausted.
///
/// # Errors
/// The current implementation always returns `Ok`.
pub fn nelder_mead<T, F>(
    f: F,
    x0: &[T],
    config: Option<OptimizeConfig<T>>,
) -> Result<OptimizeResult<T>>
where
    T: Float + std::fmt::Debug + std::iter::Sum,
    F: Fn(&[T]) -> T,
{
    let cfg = config.unwrap_or_default();
    let n = x0.len();

    // Zero-dimensional problem: `x0` is the only candidate. Handled up front because
    // the main loop needs at least two vertices (it reads the second-worst one).
    if n == 0 {
        return Ok(OptimizeResult {
            x: Vec::new(),
            fun: f(x0),
            grad: Vec::new(),
            nit: 0,
            nfev: 1,
            njev: 0,
            success: true,
            message: "Optimization terminated successfully (simplex converged)".to_string(),
        });
    }

    let alpha = T::one(); // reflection
    let gamma = T::from(2.0).expect("2.0 should be representable in Float"); // expansion
    let rho = T::from(0.5).expect("0.5 should be representable in Float"); // contraction
    let sigma = T::from(0.5).expect("0.5 should be representable in Float"); // shrink
    let step = T::from(0.05).expect("0.05 should be representable in Float");

    let mut simplex = nm_initial_simplex(x0, step);
    let mut f_vals: Vec<T> = simplex.iter().map(|x| f(x)).collect();
    let mut nfev = n + 1;

    for iter in 0..cfg.max_iter {
        // Vertex indices ordered best (lowest f) to worst; NaN values sort last.
        let mut indices: Vec<usize> = (0..=n).collect();
        indices.sort_by(|&i, &j| nm_cmp_nan_last(f_vals[i], f_vals[j]));
        let best_idx = indices[0];
        let worst_idx = indices[n];
        let second_worst_idx = indices[n - 1];

        let f_range = f_vals[worst_idx] - f_vals[best_idx];
        if f_range < cfg.ftol {
            return Ok(OptimizeResult {
                x: simplex[best_idx].clone(),
                fun: f_vals[best_idx],
                grad: vec![T::zero(); n],
                nit: iter,
                nfev,
                njev: 0,
                success: true,
                message: "Optimization terminated successfully (simplex converged)".to_string(),
            });
        }

        // Centroid of every vertex except the worst one.
        let centroid = nm_centroid(&simplex, &indices[..n]);

        // Reflect the worst vertex through the centroid.
        let x_r: Vec<T> = centroid
            .iter()
            .zip(&simplex[worst_idx])
            .map(|(&c, &w)| c + alpha * (c - w))
            .collect();
        let f_r = f(&x_r);
        nfev += 1;

        if f_r < f_vals[best_idx] {
            // The reflection is a new best point: try going further the same way.
            let x_e = nm_ray_point(&centroid, &x_r, gamma);
            let f_e = f(&x_e);
            nfev += 1;
            if f_e < f_r {
                simplex[worst_idx] = x_e;
                f_vals[worst_idx] = f_e;
            } else {
                simplex[worst_idx] = x_r;
                f_vals[worst_idx] = f_r;
            }
        } else if f_r < f_vals[second_worst_idx] {
            simplex[worst_idx] = x_r;
            f_vals[worst_idx] = f_r;
        } else {
            // Outside contraction (toward x_r) when the reflection still beats the worst
            // vertex; otherwise inside contraction (toward the worst vertex). Either is
            // accepted only if it improves on the point it was contracted toward.
            let (x_c, accept_below) = if f_r < f_vals[worst_idx] {
                (nm_ray_point(&centroid, &x_r, rho), f_r)
            } else {
                (
                    nm_ray_point(&centroid, &simplex[worst_idx], rho),
                    f_vals[worst_idx],
                )
            };
            let f_c = f(&x_c);
            nfev += 1;
            if f_c < accept_below {
                simplex[worst_idx] = x_c;
                f_vals[worst_idx] = f_c;
            } else {
                // Contraction failed: shrink every non-best vertex toward the best one.
                let best = simplex[best_idx].clone();
                for &idx in &indices[1..] {
                    let shrunk = nm_ray_point(&best, &simplex[idx], sigma);
                    f_vals[idx] = f(&shrunk);
                    simplex[idx] = shrunk;
                    nfev += 1;
                }
            }
        }
    }

    // Budget exhausted: report the best vertex seen (first one on ties, NaN last).
    let best_idx = f_vals
        .iter()
        .enumerate()
        .min_by(|a, b| nm_cmp_nan_last(*a.1, *b.1))
        .map(|(idx, _)| idx)
        .expect("f_vals should not be empty at end of Nelder-Mead iteration");
    Ok(OptimizeResult {
        x: simplex[best_idx].clone(),
        fun: f_vals[best_idx],
        grad: vec![T::zero(); n],
        nit: cfg.max_iter,
        nfev,
        njev: 0,
        success: false,
        message: "Maximum iterations reached".to_string(),
    })
}

/// Builds the starting simplex.
///
/// The simplex is `x0` followed by one copy of `x0` per coordinate `i`, with
/// coordinate `i` increased by `step`. The result has `x0.len() + 1` vertices.
fn nm_initial_simplex<T: Float>(x0: &[T], step: T) -> Vec<Vec<T>> {
    std::iter::once(x0.to_vec())
        .chain((0..x0.len()).map(|i| {
            let mut vertex = x0.to_vec();
            vertex[i] = vertex[i] + step;
            vertex
        }))
        .collect()
}

/// Returns the arithmetic mean of the simplex vertices whose indices are in `members`.
///
/// Vertices are summed in the order given. `members` must be non-empty, and all
/// vertices must have the same length.
fn nm_centroid<T: Float>(simplex: &[Vec<T>], members: &[usize]) -> Vec<T> {
    let dim = simplex[members[0]].len();
    let mut sum = vec![T::zero(); dim];
    for &idx in members {
        for (acc, &v) in sum.iter_mut().zip(&simplex[idx]) {
            *acc = *acc + v;
        }
    }
    let count = T::from(members.len()).expect("n should be representable in Float");
    sum.into_iter().map(|s| s / count).collect()
}

/// Returns `origin + t * (toward - origin)` component-wise.
///
/// This is the point a fraction `t` of the way from `origin` to `toward`; `t` may
/// exceed 1.
fn nm_ray_point<T: Float>(origin: &[T], toward: &[T], t: T) -> Vec<T> {
    origin
        .iter()
        .zip(toward)
        .map(|(&o, &p)| o + t * (p - o))
        .collect()
}

/// Total order used to rank objective values.
///
/// Uses ordinary numeric order, with NaN ranked after (worse than) every non-NaN
/// value. Unlike `partial_cmp(..).unwrap_or(Equal)`, this is transitive, as
/// `slice::sort_by` requires.
fn nm_cmp_nan_last<T: Float>(a: T, b: T) -> std::cmp::Ordering {
    use std::cmp::Ordering;
    match (a.is_nan(), b.is_nan()) {
        (false, false) => a.partial_cmp(&b).unwrap_or(Ordering::Equal),
        (false, true) => Ordering::Less,
        (true, false) => Ordering::Greater,
        (true, true) => Ordering::Equal,
    }
}