use core::marker::PhantomData;
use crate::core::constraint::BoxConstraints;
use crate::core::math::{ClampInPlace, ScaledAdd};
use crate::core::problem::{CostFunction, Problem};
use crate::core::solver::Solver;
use crate::core::state::BasicSimplexState;
use crate::core::termination::TerminationReason;
/// Nelder–Mead downhill-simplex solver.
///
/// `Mode` is a type-level switch: [`Unbounded`] evaluates trial points as they
/// are, while [`Projected`] clamps every vertex and trial point into the
/// problem's lower/upper bounds before evaluating it.
#[derive(Debug, Clone)]
pub struct NelderMead<Mode = Unbounded> {
    /// How the step coefficients are chosen.
    config: ParamConfig,
    /// Coefficients resolved for the problem dimension; `None` until the
    /// solver's `init` has run.
    params: Option<Params>,
    /// `fn() -> Mode` marks the mode without storing it, so the marker type
    /// does not influence `Send`/`Sync` or drop checking.
    _mode: PhantomData<fn() -> Mode>,
}

/// Mode marker: trial points are evaluated without any projection.
#[derive(Debug, Clone, Copy, Default)]
pub struct Unbounded;

/// Mode marker: points are clamped to the problem's lower/upper bounds.
#[derive(Debug, Clone, Copy, Default)]
pub struct Projected;

/// Concrete Nelder–Mead step coefficients.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Params {
    /// Reflection coefficient.
    alpha: f64,
    /// Expansion coefficient.
    beta: f64,
    /// Contraction coefficient (used for outside and inside contraction).
    gamma: f64,
    /// Shrink coefficient.
    delta: f64,
}

/// Strategy for obtaining [`Params`].
#[derive(Debug, Clone, Copy)]
enum ParamConfig {
    /// Fixed textbook coefficients.
    Standard,
    /// Coefficients derived from the problem dimension.
    Adaptive,
    /// Caller-supplied coefficients.
    Fixed(Params),
}
impl NelderMead<Unbounded> {
    /// Creates a solver with the standard coefficients
    /// `α = 1`, `β = 2`, `γ = 0.5`, `δ = 0.5`.
    #[must_use]
    pub fn standard() -> Self {
        Self {
            config: ParamConfig::Standard,
            params: None,
            _mode: PhantomData,
        }
    }

    /// Creates a solver whose coefficients are derived from the problem
    /// dimension when the solver is initialised.
    #[must_use]
    pub fn adaptive() -> Self {
        Self {
            config: ParamConfig::Adaptive,
            params: None,
            _mode: PhantomData,
        }
    }

    /// Creates a solver with explicit coefficients: `alpha` (reflection),
    /// `beta` (expansion), `gamma` (contraction) and `delta` (shrink).
    ///
    /// NOTE: `beta > alpha` is not checked here.
    ///
    /// # Panics
    /// Panics unless `alpha > 0`, `beta > 1`, `0 < gamma < 1` and
    /// `0 < delta < 1`. A NaN argument fails its check.
    #[must_use]
    pub fn with_params(alpha: f64, beta: f64, gamma: f64, delta: f64) -> Self {
        assert!(alpha > 0.0, "α must be > 0");
        assert!(beta > 1.0, "β must be > 1");
        assert!(gamma > 0.0 && gamma < 1.0, "γ must be in (0, 1)");
        assert!(delta > 0.0 && delta < 1.0, "δ must be in (0, 1)");
        Self {
            config: ParamConfig::Fixed(Params {
                alpha,
                beta,
                gamma,
                delta,
            }),
            params: None,
            _mode: PhantomData,
        }
    }

    /// Converts this solver into one that clamps every vertex and trial point
    /// to the problem's box constraints, keeping the same parameter choice.
    #[must_use]
    pub fn projected(self) -> NelderMead<Projected> {
        NelderMead {
            config: self.config,
            params: self.params,
            _mode: PhantomData,
        }
    }
}
impl<Mode> NelderMead<Mode> {
    /// Turns the configured parameter choice into concrete coefficients for an
    /// `n`-dimensional problem (a simplex of `n + 1` vertices).
    ///
    /// `Adaptive` uses dimension-dependent coefficients
    /// (`β = 1 + 2/n`, `γ = 0.75 − 1/(2n)`, `δ = 1 − 1/n`). For `n == 1` these
    /// formulas would give `δ = 0`, which collapses the whole simplex onto the
    /// best vertex at the first shrink and stalls the search; `with_params`
    /// rejects such a `δ` for the same reason. The dimension is therefore
    /// floored at 2, where the adaptive coefficients equal the standard ones.
    ///
    /// # Panics
    /// Panics if `n == 0`.
    fn resolve(config: ParamConfig, n: usize) -> Params {
        assert!(n >= 1, "NelderMead requires at least a 1-D problem");
        match config {
            ParamConfig::Standard => Params {
                alpha: 1.0,
                beta: 2.0,
                gamma: 0.5,
                delta: 0.5,
            },
            ParamConfig::Adaptive => {
                // Flooring at 2 keeps every coefficient strictly inside its
                // valid range (in particular δ > 0) for 1-D problems.
                let n = n.max(2) as f64;
                Params {
                    alpha: 1.0,
                    beta: 1.0 + 2.0 / n,
                    gamma: 0.75 - 1.0 / (2.0 * n),
                    delta: 1.0 - 1.0 / n,
                }
            }
            ParamConfig::Fixed(p) => p,
        }
    }
}
/// Returns the point `a + t·(b − a)` on the line through `a` and `b`.
///
/// The result is accumulated as `a − t·a + t·b`, so no temporary holding
/// `b − a` is needed. `t = 0` yields `a`, `t = 1` yields `b`, and negative `t`
/// lands on the far side of `a` (as used for reflection).
fn affine<V>(a: &V, b: &V, t: f64) -> V
where
    V: Clone + ScaledAdd<f64>,
{
    let mut point = a.clone();
    for &(weight, term) in &[(-t, a), (t, b)] {
        point.scaled_add(weight, term);
    }
    point
}
/// Returns the arithmetic mean of `vertices`.
///
/// The first vertex is cloned and rescaled to `1/len` of itself. Every other
/// vertex is then added with weight `1/len`.
///
/// # Panics
/// Panics if `vertices` is empty.
fn centroid<V>(vertices: &[V]) -> V
where
    V: Clone + ScaledAdd<f64>,
{
    let weight = 1.0 / vertices.len() as f64;
    let first = &vertices[0];
    // x + (w − 1)·x == w·x: reuse the clone instead of building a zero vector.
    let mut seed = first.clone();
    seed.scaled_add(weight - 1.0, first);
    vertices[1..].iter().fold(seed, |mut sum, vertex| {
        sum.scaled_add(weight, vertex);
        sum
    })
}
/// Reorders `vertices` and `costs` together so that costs are ascending.
///
/// NaN costs compare greater than every number and equal to each other. A
/// vertex whose cost is NaN therefore always becomes the worst vertex and is
/// the first to be replaced. This rule also makes the comparator a genuine
/// total order, which `sort_by` requires; otherwise it may panic or produce
/// an unspecified order.
///
/// The sort is stable, so vertices with equal costs keep their current
/// relative order. A freshly accepted point at the last index stays behind
/// its ties.
///
/// # Panics
/// Panics if `vertices` and `costs` have different lengths.
fn sort_simplex<V>(vertices: &mut [V], costs: &mut [f64]) {
    assert_eq!(
        vertices.len(),
        costs.len(),
        "simplex must have exactly one cost per vertex"
    );
    let mut order: Vec<usize> = (0..vertices.len()).collect();
    order.sort_by(|&i, &j| {
        let (a, b) = (costs[i], costs[j]);
        // `partial_cmp` only fails when a NaN is involved: rank NaN last.
        a.partial_cmp(&b)
            .unwrap_or_else(|| a.is_nan().cmp(&b.is_nan()))
    });
    apply_permutation(vertices, &order);
    apply_permutation(costs, &order);
}
/// Rearranges `slice` in place so that afterwards `slice[k]` holds the element
/// that was previously at `slice[idx[k]]`.
///
/// The permutation is applied one cycle at a time using swaps, so no element
/// is cloned. `idx` is expected to be a permutation of `0..slice.len()`, as
/// built by `sort_simplex`.
/// NOTE(review): an `idx` that is not a permutation can make the cycle walk
/// loop forever.
fn apply_permutation<T>(slice: &mut [T], idx: &[usize]) {
    let len = slice.len();
    let mut placed = vec![false; len];
    for start in 0..len {
        // Skip positions already settled by an earlier cycle, and fixed points.
        if placed[start] || idx[start] == start {
            continue;
        }
        let mut slot = start;
        let mut source = idx[slot];
        placed[slot] = true;
        // Pull each element into its slot; the displaced value travels along
        // the cycle until it reaches the position that asks for `start`.
        while source != start {
            slice.swap(slot, source);
            slot = source;
            source = idx[slot];
            placed[slot] = true;
        }
    }
}
/// Evaluates the cost of every vertex into the matching `costs` slot, then
/// sorts the simplex by ascending cost.
///
/// Vertices and cost slots are paired with `zip`, so only as many vertices as
/// there are slots (or vice versa) are evaluated.
///
/// # Errors
/// Returns the first error raised by the cost function. Slots filled before
/// the error keep their new values, and the simplex is left unsorted.
fn init_costs_and_sort<P, V>(
    problem: &mut Problem<P>,
    state: &mut BasicSimplexState<V>,
) -> Result<(), P::Error>
where
    P: CostFunction<Param = V, Output = f64>,
{
    state
        .vertices
        .iter()
        .zip(state.costs.iter_mut())
        .try_for_each(|(vertex, slot)| -> Result<(), P::Error> {
            *slot = problem.cost(vertex)?;
            Ok(())
        })?;
    sort_simplex(&mut state.vertices, &mut state.costs);
    Ok(())
}
/// Performs one Nelder–Mead step on a simplex sorted by ascending cost, then
/// re-sorts it.
///
/// The reflection of the worst vertex through the centroid of the others is
/// evaluated first. Depending on how its cost compares with the best, second
/// worst and worst costs, the step then:
/// - accepts the reflection,
/// - tries an expansion,
/// - tries an outside contraction, or
/// - tries an inside contraction.
///
/// A rejected contraction shrinks the simplex towards the best vertex via
/// `shrink_inner`. Every trial point is passed through `project` before it is
/// evaluated. The returned termination reason is always `None`.
///
/// # Errors
/// Propagates any error from evaluating the cost function.
fn next_iter_inner<P, V, F>(
    problem: &mut Problem<P>,
    mut state: BasicSimplexState<V>,
    p: Params,
    project: &F,
) -> Result<(BasicSimplexState<V>, Option<TerminationReason>), P::Error>
where
    P: CostFunction<Param = V, Output = f64>,
    V: Clone + ScaledAdd<f64>,
    F: Fn(&mut V),
{
    // With `last + 1` vertices, index `last` is the worst and `last - 1` the
    // second worst; the centroid excludes the worst vertex.
    let last = state.vertices.len() - 1;
    let x_bar = centroid(&state.vertices[..last]);
    let best_cost = state.costs[0];
    let second_worst_cost = state.costs[last - 1];
    let worst_cost = state.costs[last];

    // Projects a trial point onto the feasible set and evaluates it.
    let mut evaluate = |mut trial: V| -> Result<(V, f64), P::Error> {
        project(&mut trial);
        let cost = problem.cost(&trial)?;
        Ok((trial, cost))
    };

    let (x_r, fr) = evaluate(affine(&x_bar, &state.vertices[last], -p.alpha))?;

    // `Some` carries the point that replaces the worst vertex; `None` asks for
    // a shrink. The branch order matters when costs are NaN, so keep it.
    let replacement: Option<(V, f64)> = if best_cost <= fr && fr < second_worst_cost {
        Some((x_r, fr))
    } else if fr < best_cost {
        let (x_e, fe) = evaluate(affine(&x_bar, &x_r, p.beta))?;
        Some(if fe < fr { (x_e, fe) } else { (x_r, fr) })
    } else if fr < worst_cost {
        let (x_oc, foc) = evaluate(affine(&x_bar, &x_r, p.gamma))?;
        if foc <= fr {
            Some((x_oc, foc))
        } else {
            None
        }
    } else {
        let (x_ic, fic) = evaluate(affine(&x_bar, &state.vertices[last], p.gamma))?;
        if fic < worst_cost {
            Some((x_ic, fic))
        } else {
            None
        }
    };

    match replacement {
        Some((vertex, cost)) => {
            state.vertices[last] = vertex;
            state.costs[last] = cost;
        }
        None => shrink_inner(problem, &mut state, p.delta, project)?,
    }

    sort_simplex(&mut state.vertices, &mut state.costs);
    Ok((state, None))
}
/// Shrinks the simplex towards its first (best) vertex.
///
/// Every other vertex `v` is replaced by `best + delta·(v − best)`, passed
/// through `project`, and re-evaluated. The best vertex and its cost are left
/// untouched, and the simplex is not re-sorted here.
///
/// # Errors
/// Returns the first cost-function error. Vertices processed before it have
/// already been moved and re-costed.
fn shrink_inner<P, V, F>(
    problem: &mut Problem<P>,
    state: &mut BasicSimplexState<V>,
    delta: f64,
    project: &F,
) -> Result<(), P::Error>
where
    P: CostFunction<Param = V, Output = f64>,
    V: Clone + ScaledAdd<f64>,
    F: Fn(&mut V),
{
    // Split so the anchor can be read while the remaining vertices are mutated.
    let (head, tail) = state.vertices.split_at_mut(1);
    let anchor = &head[0];
    tail.iter_mut()
        .zip(&mut state.costs[1..])
        .try_for_each(|(vertex, cost)| -> Result<(), P::Error> {
            let mut moved = affine(anchor, vertex, delta);
            project(&mut moved);
            *vertex = moved;
            *cost = problem.cost(vertex)?;
            Ok(())
        })
}
impl<P, V> Solver<P, BasicSimplexState<V>> for NelderMead<Unbounded>
where
    P: CostFunction<Param = V, Output = f64>,
    V: Clone + ScaledAdd<f64>,
{
    type Error = P::Error;

    /// Resolves the step coefficients for the simplex dimension, evaluates
    /// every vertex and sorts the simplex by ascending cost.
    ///
    /// # Panics
    /// Panics if the simplex has fewer than two vertices (a 0-D problem).
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicSimplexState<V>,
    ) -> Result<BasicSimplexState<V>, Self::Error> {
        // `saturating_sub` keeps an empty simplex from wrapping to `usize::MAX`
        // in release builds; `resolve` then rejects it with a clear message.
        let n = state.vertices.len().saturating_sub(1);
        self.params = Some(Self::resolve(self.config, n));
        init_costs_and_sort(problem, &mut state)?;
        Ok(state)
    }

    /// Performs one Nelder–Mead step without any projection of trial points.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        state: BasicSimplexState<V>,
    ) -> Result<(BasicSimplexState<V>, Option<TerminationReason>), Self::Error> {
        let p = self
            .params
            .expect("NelderMead::init must run before next_iter");
        next_iter_inner(problem, state, p, &|_: &mut V| {})
    }
}
impl<P, V> Solver<P, BasicSimplexState<V>> for NelderMead<Projected>
where
    P: CostFunction<Param = V, Output = f64> + BoxConstraints,
    V: Clone + ScaledAdd<f64> + ClampInPlace,
{
    type Error = P::Error;

    /// Resolves the step coefficients, clamps every initial vertex into the
    /// problem's bounds, then evaluates and sorts the simplex.
    ///
    /// # Panics
    /// Panics if the simplex has fewer than two vertices (a 0-D problem).
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicSimplexState<V>,
    ) -> Result<BasicSimplexState<V>, Self::Error> {
        // `saturating_sub` keeps an empty simplex from wrapping to `usize::MAX`
        // in release builds; `resolve` then rejects it with a clear message.
        let n = state.vertices.len().saturating_sub(1);
        self.params = Some(Self::resolve(self.config, n));
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        for v in state.vertices.iter_mut() {
            v.clamp_in_place(&lo, &hi);
        }
        init_costs_and_sort(problem, &mut state)?;
        Ok(state)
    }

    /// Performs one Nelder–Mead step in which every trial point is clamped to
    /// the problem's bounds before being evaluated.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        state: BasicSimplexState<V>,
    ) -> Result<(BasicSimplexState<V>, Option<TerminationReason>), Self::Error> {
        let p = self
            .params
            .expect("NelderMead::init must run before next_iter");
        // The bounds are cloned because `problem` is borrowed mutably by
        // `next_iter_inner` while the projection closure still needs them.
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        next_iter_inner(problem, state, p, &|v: &mut V| v.clamp_in_place(&lo, &hi))
    }
}