use core::marker::PhantomData;
use crate::core::constraint::BoxConstraints;
use crate::core::math::{ClampInPlace, Scalar, ScaledAdd};
use crate::core::problem::{CostFunction, Problem};
use crate::core::solver::Solver;
use crate::core::state::BasicSimplexState;
use crate::core::termination::TerminationReason;
/// Nelder–Mead downhill-simplex solver.
///
/// `Mode` is a type-level switch: [`Unbounded`] evaluates trial points as
/// they are, while [`Projected`] clamps every trial point into the problem's
/// box constraints before evaluating it. `F` is the scalar type used for
/// costs and coefficients.
#[derive(Debug, Clone)]
pub struct NelderMead<Mode = Unbounded, F = f64> {
    /// How the coefficients are chosen (standard, adaptive, or fixed).
    config: ParamConfig<F>,
    /// Coefficients resolved from `config` during `init`; `None` before that.
    params: Option<Params<F>>,
    /// `fn() -> Mode` records the mode without owning a `Mode` value, so the
    /// marker does not influence `Send`/`Sync` or drop checking.
    _mode: PhantomData<fn() -> Mode>,
}

/// Mode marker: trial points are evaluated without projection.
#[derive(Debug, Clone, Copy, Default)]
pub struct Unbounded;

/// Mode marker: trial points are clamped into the box constraints first.
#[derive(Debug, Clone, Copy, Default)]
pub struct Projected;

/// Concrete Nelder–Mead coefficients.
#[derive(Debug, Clone, Copy, PartialEq)]
struct Params<F> {
    /// Reflection: `x_r = x̄ + α·(x̄ − x_worst)`.
    alpha: F,
    /// Expansion, applied relative to the reflected point: `x_e = x̄ + β·(x_r − x̄)`.
    beta: F,
    /// Contraction factor (used for both outside and inside contraction).
    gamma: F,
    /// Shrink factor toward the best vertex.
    delta: F,
}

/// Source of the coefficients; turned into [`Params`] once the dimension is known.
#[derive(Debug, Clone, Copy, PartialEq)]
enum ParamConfig<F> {
    /// Fixed textbook coefficients.
    Standard,
    /// Dimension-dependent coefficients.
    Adaptive,
    /// User-supplied coefficients.
    Fixed(Params<F>),
}
impl<F: Scalar> NelderMead<Unbounded, F> {
pub fn standard() -> Self {
Self {
config: ParamConfig::Standard,
params: None,
_mode: PhantomData,
}
}
pub fn adaptive() -> Self {
Self {
config: ParamConfig::Adaptive,
params: None,
_mode: PhantomData,
}
}
pub fn with_params(alpha: F, beta: F, gamma: F, delta: F) -> Self {
assert!(alpha > F::zero(), "α must be > 0");
assert!(beta > F::one(), "β must be > 1");
assert!(gamma > F::zero() && gamma < F::one(), "γ must be in (0, 1)");
assert!(delta > F::zero() && delta < F::one(), "δ must be in (0, 1)");
Self {
config: ParamConfig::Fixed(Params {
alpha,
beta,
gamma,
delta,
}),
params: None,
_mode: PhantomData,
}
}
pub fn projected(self) -> NelderMead<Projected, F> {
NelderMead {
config: self.config,
params: self.params,
_mode: PhantomData,
}
}
}
impl<Mode, F: Scalar> NelderMead<Mode, F> {
    /// Standard coefficients: α = 1, β = 2, γ = δ = ½.
    fn standard_params() -> Params<F> {
        let half = F::from_f64(0.5).unwrap();
        Params {
            alpha: F::one(),
            beta: F::from_f64(2.0).unwrap(),
            gamma: half,
            delta: half,
        }
    }

    /// Turns `config` into concrete coefficients for an `n`-dimensional
    /// problem (a simplex with `n + 1` vertices).
    ///
    /// * `Standard`: α = 1, β = 2, γ = δ = ½.
    /// * `Adaptive`: the Gao–Han coefficients α = 1, β = 1 + 2/n,
    ///   γ = ¾ − 1/(2n), δ = 1 − 1/n. At n = 1 that formula gives δ = 0,
    ///   which would collapse every vertex onto the best one on a shrink
    ///   step (and violates the δ ∈ (0, 1) rule `with_params` enforces).
    ///   The standard coefficients are therefore used for n < 2; at n = 2
    ///   both sets coincide anyway.
    /// * `Fixed`: returned unchanged.
    ///
    /// # Panics
    /// Panics if `n == 0`, or if `F` cannot represent the constants used.
    fn resolve(config: ParamConfig<F>, n: usize) -> Params<F> {
        assert!(n >= 1, "NelderMead requires at least a 1-D problem");
        match config {
            ParamConfig::Standard => Self::standard_params(),
            ParamConfig::Adaptive if n < 2 => Self::standard_params(),
            ParamConfig::Adaptive => {
                let n = F::from_usize(n).unwrap();
                let two = F::from_f64(2.0).unwrap();
                Params {
                    alpha: F::one(),
                    beta: F::one() + two / n,
                    gamma: F::from_f64(0.75).unwrap() - F::one() / (two * n),
                    delta: F::one() - F::one() / n,
                }
            }
            ParamConfig::Fixed(p) => p,
        }
    }
}
/// Point on the line through `a` and `b`: `a + t·(b − a)`.
///
/// Built as `a − t·a + t·b` using only `Clone` and `ScaledAdd`, so `V` needs
/// no subtraction or scaling of its own. `t` may be negative (reflection)
/// or greater than 1 (expansion).
fn affine<V, F>(a: &V, b: &V, t: F) -> V
where
    V: Clone + ScaledAdd<F>,
    F: Scalar,
{
    let mut point = V::clone(a);
    for (weight, term) in [(-t, a), (t, b)] {
        point.scaled_add(weight, term);
    }
    point
}
/// Arithmetic mean of `vertices`, built from `Clone` and `ScaledAdd` only.
///
/// The first vertex is rescaled in place from weight 1 to weight 1/k
/// (k = number of vertices) by adding `(1/k − 1)·v₀`. Every remaining
/// vertex is then accumulated with weight 1/k.
///
/// # Panics
/// Panics if `vertices` is empty, or if `F` cannot represent its length.
fn centroid<V, F>(vertices: &[V]) -> V
where
    V: Clone + ScaledAdd<F>,
    F: Scalar,
{
    let weight = F::from_usize(vertices.len()).unwrap().recip();
    let first = &vertices[0];
    let mut mean = first.clone();
    mean.scaled_add(weight - F::one(), first);
    vertices[1..]
        .iter()
        .for_each(|vertex| mean.scaled_add(weight, vertex));
    mean
}
/// Sorts the simplex in place by ascending cost. `vertices` and `costs` are
/// permuted together, so each vertex keeps its own cost.
///
/// The sort is stable: vertices with equal cost keep their relative order,
/// so a vertex just written into the last slot stays behind existing ties.
///
/// Costs that are incomparable even with themselves (NaN) are ordered after
/// every comparable cost. A NaN vertex therefore becomes the worst vertex,
/// which is the one replaced next. Mapping NaN to `Equal`, as a plain
/// `partial_cmp(..).unwrap_or(Equal)` would, is not a total order. `sort_by`
/// requires a total order and may panic when it is violated.
fn sort_simplex<V, F: PartialOrd>(vertices: &mut [V], costs: &mut [F]) {
    debug_assert_eq!(
        vertices.len(),
        costs.len(),
        "every simplex vertex needs exactly one cost"
    );
    let n = vertices.len();
    let mut idx: Vec<usize> = (0..n).collect();
    idx.sort_by(|&i, &j| cmp_nan_last(&costs[i], &costs[j]));
    apply_permutation(vertices, &idx);
    apply_permutation(costs, &idx);
}

/// Total order used for sorting costs. Comparable values keep their usual
/// order; values unordered with themselves (NaN) compare greater than
/// everything else and equal to each other.
fn cmp_nan_last<F: PartialOrd>(a: &F, b: &F) -> std::cmp::Ordering {
    a.partial_cmp(b).unwrap_or_else(|| {
        let a_unordered = a.partial_cmp(a).is_none();
        let b_unordered = b.partial_cmp(b).is_none();
        // `false < true`, so the unordered value sorts last.
        a_unordered.cmp(&b_unordered)
    })
}

/// Reorders `slice` in place so that `new[k] = old[idx[k]]`, by following
/// each cycle of the permutation with swaps.
///
/// `idx` must be a permutation of `0..slice.len()`.
fn apply_permutation<T>(slice: &mut [T], idx: &[usize]) {
    let mut visited = vec![false; slice.len()];
    for start in 0..slice.len() {
        if visited[start] || idx[start] == start {
            visited[start] = true;
            continue;
        }
        let mut current = start;
        loop {
            let next = idx[current];
            visited[current] = true;
            if next == start {
                break;
            }
            // Pull the element that belongs at `current` into place; the
            // displaced element travels further along the cycle.
            slice.swap(current, next);
            current = next;
        }
    }
}
/// Evaluates the cost of every vertex, then sorts the simplex so that
/// index 0 holds the lowest-cost vertex.
///
/// Vertices and cost slots are paired positionally. Costs are written into
/// existing slots and never pushed.
/// NOTE(review): this assumes `state.costs` already has one slot per vertex.
///
/// # Errors
/// Returns the first error from `problem.cost`; the remaining vertices are
/// not evaluated and the simplex is not sorted.
fn init_costs_and_sort<P, V, F>(
    problem: &mut Problem<P>,
    state: &mut BasicSimplexState<V, F>,
) -> Result<(), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
{
    state
        .vertices
        .iter()
        .zip(state.costs.iter_mut())
        .try_for_each(|(vertex, slot)| {
            *slot = problem.cost(vertex)?;
            Ok::<(), P::Error>(())
        })?;
    sort_simplex(&mut state.vertices, &mut state.costs);
    Ok(())
}
/// Performs one Nelder–Mead iteration on a simplex sorted by ascending cost.
///
/// The centroid `x̄` of all vertices except the worst is computed, and the
/// worst vertex is reflected through it. Depending on the reflected cost
/// `f_r`, with `f_1` the best, `f_n` the second-worst and `f_{n+1}` the worst
/// cost, the step then:
/// * accepts the reflection when `f_1 ≤ f_r < f_n`;
/// * tries an expansion when `f_r < f_1`, keeping the better of the two;
/// * tries an outside contraction when `f_n ≤ f_r < f_{n+1}`, and shrinks if
///   that does not improve on `f_r`;
/// * otherwise tries an inside contraction, and shrinks if that does not
///   beat `f_{n+1}`.
///
/// Every trial point is passed through `project` before it is evaluated.
/// The simplex is re-sorted at the end. This step never reports a
/// termination reason.
///
/// # Panics
/// Panics if the simplex has fewer than two vertices.
///
/// # Errors
/// Propagates the first error from `problem.cost`.
// The return type mirrors `Solver::next_iter`; spelling it out trips clippy's
// `type_complexity` lint, which is deliberately silenced here.
#[allow(clippy::type_complexity)]
fn next_iter_inner<P, V, F, Proj>(
    problem: &mut Problem<P>,
    mut state: BasicSimplexState<V, F>,
    p: Params<F>,
    project: &Proj,
) -> Result<(BasicSimplexState<V, F>, Option<TerminationReason>), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: Clone + ScaledAdd<F>,
    Proj: Fn(&mut V),
{
    let m = state.vertices.len();
    // Without this, `m - 1` / `n - 1` below would underflow or index out of
    // range with an unhelpful message.
    assert!(
        m >= 2,
        "NelderMead needs a simplex with at least 2 vertices, got {}",
        m
    );
    let n = m - 1;
    let worst = m - 1;
    let x_bar = centroid(&state.vertices[..n]);
    let f1 = state.costs[0];
    let fn_ = state.costs[n - 1];
    let fnp1 = state.costs[worst];

    // Reflection: x_r = x̄ + α·(x̄ − x_worst).
    let mut x_r = affine(&x_bar, &state.vertices[worst], -p.alpha);
    project(&mut x_r);
    let fr = problem.cost(&x_r)?;

    if f1 <= fr && fr < fn_ {
        replace_worst(&mut state, x_r, fr);
    } else if fr < f1 {
        // Expansion further along the reflection direction.
        let mut x_e = affine(&x_bar, &x_r, p.beta);
        project(&mut x_e);
        let fe = problem.cost(&x_e)?;
        if fe < fr {
            replace_worst(&mut state, x_e, fe);
        } else {
            replace_worst(&mut state, x_r, fr);
        }
    } else if fr < fnp1 {
        // Outside contraction between x̄ and x_r.
        let mut x_oc = affine(&x_bar, &x_r, p.gamma);
        project(&mut x_oc);
        let foc = problem.cost(&x_oc)?;
        if foc <= fr {
            replace_worst(&mut state, x_oc, foc);
        } else {
            shrink_inner(problem, &mut state, p.delta, project)?;
        }
    } else {
        // Inside contraction between x̄ and the worst vertex.
        let mut x_ic = affine(&x_bar, &state.vertices[worst], p.gamma);
        project(&mut x_ic);
        let fic = problem.cost(&x_ic)?;
        if fic < fnp1 {
            replace_worst(&mut state, x_ic, fic);
        } else {
            shrink_inner(problem, &mut state, p.delta, project)?;
        }
    }
    sort_simplex(&mut state.vertices, &mut state.costs);
    Ok((state, None))
}

/// Overwrites the worst (last) vertex of the simplex and its cost.
fn replace_worst<V, F: Scalar>(state: &mut BasicSimplexState<V, F>, vertex: V, cost: F) {
    let worst = state.vertices.len() - 1;
    state.vertices[worst] = vertex;
    state.costs[worst] = cost;
}
/// Shrink step: moves every vertex except the best one (index 0) toward the
/// best vertex, `v ← best + δ·(v − best)`. Each moved vertex is projected,
/// stored, and then re-evaluated.
///
/// The best vertex and its cost are left untouched. The simplex is not
/// re-sorted here.
///
/// # Errors
/// Stops at, and returns, the first error from `problem.cost`. The vertex
/// being evaluated has already been overwritten at that point.
fn shrink_inner<P, V, F, Proj>(
    problem: &mut Problem<P>,
    state: &mut BasicSimplexState<V, F>,
    delta: F,
    project: &Proj,
) -> Result<(), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: Clone + ScaledAdd<F>,
    Proj: Fn(&mut V),
{
    let (head, tail) = state.vertices.split_at_mut(1);
    let anchor = &head[0];
    tail.iter_mut()
        .zip(state.costs[1..].iter_mut())
        .try_for_each(|(vertex, cost)| {
            let mut moved = affine(anchor, vertex, delta);
            project(&mut moved);
            *vertex = moved;
            *cost = problem.cost(vertex)?;
            Ok::<(), P::Error>(())
        })
}
impl<P, V, F> Solver<P, BasicSimplexState<V, F>> for NelderMead<Unbounded, F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: Clone + ScaledAdd<F>,
{
    type Error = P::Error;

    /// Resolves the coefficients for the simplex dimension, then evaluates
    /// and sorts the initial vertices.
    ///
    /// # Panics
    /// Panics if the simplex has fewer than two vertices.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicSimplexState<V, F>,
    ) -> Result<BasicSimplexState<V, F>, Self::Error> {
        // `saturating_sub` stops an empty simplex from wrapping to
        // `usize::MAX` in release builds, where it would slip past the
        // `n >= 1` check in `resolve`.
        let n = state.vertices.len().saturating_sub(1);
        self.params = Some(Self::resolve(self.config, n));
        init_costs_and_sort(problem, &mut state)?;
        Ok(state)
    }

    /// Runs one unprojected Nelder–Mead step.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        state: BasicSimplexState<V, F>,
    ) -> Result<(BasicSimplexState<V, F>, Option<TerminationReason>), Self::Error> {
        let p = self
            .params
            .expect("NelderMead::init must run before next_iter");
        // Unbounded mode: the projection is a no-op.
        next_iter_inner(problem, state, p, &|_: &mut V| {})
    }
}
impl<P, V, F> Solver<P, BasicSimplexState<V, F>> for NelderMead<Projected, F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F> + BoxConstraints,
    V: Clone + ScaledAdd<F> + ClampInPlace,
{
    type Error = P::Error;

    /// Resolves the coefficients, clamps every initial vertex into the
    /// problem's `[lower, upper]` box, then evaluates and sorts the simplex.
    ///
    /// NOTE(review): if clamping maps several vertices onto the same point,
    /// the simplex degenerates. Consider validating the initial simplex.
    ///
    /// # Panics
    /// Panics if the simplex has fewer than two vertices.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicSimplexState<V, F>,
    ) -> Result<BasicSimplexState<V, F>, Self::Error> {
        // `saturating_sub` stops an empty simplex from wrapping to
        // `usize::MAX` in release builds, where it would slip past the
        // `n >= 1` check in `resolve`.
        let n = state.vertices.len().saturating_sub(1);
        self.params = Some(Self::resolve(self.config, n));
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        for v in state.vertices.iter_mut() {
            v.clamp_in_place(&lo, &hi);
        }
        init_costs_and_sort(problem, &mut state)?;
        Ok(state)
    }

    /// Runs one Nelder–Mead step in which every trial point is clamped into
    /// the problem's `[lower, upper]` box before evaluation.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        state: BasicSimplexState<V, F>,
    ) -> Result<(BasicSimplexState<V, F>, Option<TerminationReason>), Self::Error> {
        let p = self
            .params
            .expect("NelderMead::init must run before next_iter");
        // The bounds are cloned because `problem` is mutably borrowed by
        // `next_iter_inner` while the projection closure holds them.
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        next_iter_inner(problem, state, p, &|v: &mut V| v.clamp_in_place(&lo, &hi))
    }
}