use core::marker::PhantomData;
use crate::core::constraint::BoxConstraints;
use crate::core::math::{ClampInPlace, Scalar, ScaleInPlace, ScaledAdd};
use crate::core::problem::{CostFunction, Problem};
use crate::core::solver::Solver;
use crate::core::state::BasicSimplexState;
use crate::core::termination::TerminationReason;
/// Nelder–Mead downhill-simplex minimizer.
///
/// `Mode` selects the variant: [`Unbounded`] (the default) evaluates every trial point as
/// generated, while [`Projected`] clamps trial points into the problem's box bounds before
/// evaluating them. `F` is the scalar type used for costs and coefficients.
#[derive(Debug)]
pub struct NelderMead<Mode = Unbounded, F = f64> {
    /// How the coefficients are chosen; turned into concrete `params` by `Solver::init`.
    config: ParamConfig<F>,
    /// Concrete coefficients; `None` until `Solver::init` has run.
    params: Option<Params<F>>,
    /// Type-level mode tag. `fn() -> Mode` keeps the solver `Send`/`Sync` whatever `Mode` is.
    _mode: PhantomData<fn() -> Mode>,
}

/// Mode marker: trial points are evaluated exactly as generated.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Unbounded;

/// Mode marker: trial points are clamped into the problem's box constraints before evaluation.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Projected;

/// Concrete simplex coefficients.
#[derive(Debug, Clone, Copy)]
struct Params<F> {
    /// Reflection coefficient.
    alpha: F,
    /// Expansion coefficient (applied to the reflected step).
    beta: F,
    /// Contraction coefficient (inside and outside contraction).
    gamma: F,
    /// Shrink coefficient (pull of each non-best vertex toward the best one).
    delta: F,
}

/// How the coefficients are selected before they are resolved for a given dimension.
#[derive(Debug, Clone, Copy)]
enum ParamConfig<F> {
    /// Classic fixed coefficients.
    Standard,
    /// Coefficients that depend on the problem dimension.
    Adaptive,
    /// Caller-supplied coefficients, used verbatim.
    Fixed(Params<F>),
}
impl<F: Scalar> Default for NelderMead<Unbounded, F> {
fn default() -> Self {
Self::new()
}
}
impl<F: Scalar> NelderMead<Unbounded, F> {
    /// Creates an unbounded solver that uses the standard coefficient set, which is
    /// resolved when `Solver::init` runs.
    pub fn new() -> Self {
        NelderMead {
            config: ParamConfig::Standard,
            params: None,
            _mode: PhantomData,
        }
    }

    /// Creates an unbounded solver whose coefficients depend on the problem dimension;
    /// they are computed when `Solver::init` runs.
    pub fn adaptive() -> Self {
        Self {
            config: ParamConfig::Adaptive,
            ..Self::new()
        }
    }

    /// Creates an unbounded solver with caller-chosen coefficients:
    /// `alpha` (reflection), `beta` (expansion), `gamma` (contraction), `delta` (shrink).
    ///
    /// # Panics
    /// Panics unless `alpha > 0`, `beta > 1`, and both `gamma` and `delta` lie strictly
    /// between 0 and 1. NaN inputs fail these checks and therefore also panic.
    pub fn with_params(alpha: F, beta: F, gamma: F, delta: F) -> Self {
        let zero = F::zero();
        let one = F::one();
        assert!(alpha > zero, "α must be > 0");
        assert!(beta > one, "β must be > 1");
        assert!(gamma > zero && gamma < one, "γ must be in (0, 1)");
        assert!(delta > zero && delta < one, "δ must be in (0, 1)");
        let fixed = Params {
            alpha,
            beta,
            gamma,
            delta,
        };
        Self {
            config: ParamConfig::Fixed(fixed),
            ..Self::new()
        }
    }

    /// Converts this solver into its box-projected variant, keeping the configured
    /// coefficients (and any already-resolved parameters).
    pub fn projected(self) -> NelderMead<Projected, F> {
        let NelderMead { config, params, .. } = self;
        NelderMead {
            config,
            params,
            _mode: PhantomData,
        }
    }
}
impl<Mode, F: Scalar> NelderMead<Mode, F> {
    /// Turns a [`ParamConfig`] into concrete coefficients for an `n`-dimensional problem.
    ///
    /// * `Standard`: α = 1, β = 2, γ = ½, δ = ½.
    /// * `Adaptive` (Gao & Han): α = 1, β = 1 + 2/n, γ = ¾ − 1/(2n), δ = 1 − 1/n for n ≥ 2.
    ///   For n = 1 the standard coefficients are used instead (see below).
    /// * `Fixed`: returned unchanged.
    ///
    /// # Panics
    /// Panics if `n == 0`, or if `F` cannot represent the constants involved.
    fn resolve(config: ParamConfig<F>, n: usize) -> Params<F> {
        assert!(n >= 1, "NelderMead requires at least a 1-D problem");
        let standard = || {
            let half = F::from_f64(0.5).unwrap();
            Params {
                alpha: F::one(),
                beta: F::from_f64(2.0).unwrap(),
                gamma: half,
                delta: half,
            }
        };
        match config {
            ParamConfig::Standard => standard(),
            // The adaptive schedule yields δ = 1 − 1/n = 0 at n = 1. That breaks the
            // δ ∈ (0, 1) invariant `with_params` enforces, and the first shrink would then
            // collapse the whole simplex onto the best vertex, stalling the search forever.
            // At n = 2 the schedule coincides exactly with the standard coefficients, so use
            // those for n < 2.
            ParamConfig::Adaptive if n < 2 => standard(),
            ParamConfig::Adaptive => {
                let dim = F::from_usize(n).unwrap();
                let two = F::from_f64(2.0).unwrap();
                Params {
                    alpha: F::one(),
                    beta: F::one() + two / dim,
                    gamma: F::from_f64(0.75).unwrap() - F::one() / (two * dim),
                    delta: F::one() - F::one() / dim,
                }
            }
            ParamConfig::Fixed(p) => p,
        }
    }
}
/// Overwrites `out` with the affine combination `(1 − t)·a + t·b`.
///
/// `out` is first scaled by zero and then accumulated into, so its previous contents
/// are discarded.
fn affine_into<V, F>(out: &mut V, a: &V, b: &V, t: F)
where
    F: Scalar,
    V: ScaleInPlace<F> + ScaledAdd<F>,
{
    let weight_a = F::one() - t;
    let weight_b = t;
    out.scale_in_place(F::zero());
    out.scaled_add(weight_a, a);
    out.scaled_add(weight_b, b);
}
/// Overwrites `out` with the arithmetic mean of `vertices`.
///
/// Each vertex is added with weight `1 / vertices.len()`. For an empty slice `out`
/// simply ends up scaled to zero.
fn centroid_into<V, F>(out: &mut V, vertices: &[V])
where
    V: ScaleInPlace<F> + ScaledAdd<F>,
    F: Scalar,
{
    let count = F::from_usize(vertices.len()).unwrap();
    let weight = count.recip();
    out.scale_in_place(F::zero());
    vertices
        .iter()
        .for_each(|vertex| out.scaled_add(weight, vertex));
}
/// Sorts `vertices` and `costs` together by ascending cost, in place.
///
/// The sort is stable: a vertex only moves ahead of another when its cost is strictly
/// better, so among equal costs the earlier vertex keeps its place. Incomparable costs
/// (NaN for floats) are ordered after every comparable cost. A NaN therefore can never
/// sit at the "best" end of the simplex or block a better vertex from reaching it.
///
/// Insertion sort fits here because the simplex is nearly sorted between iterations:
/// outside a shrink only one vertex changes.
///
/// # Panics
/// Panics if there are two or more vertices and `costs` has fewer than `vertices.len()`
/// entries.
fn insertion_sort_simplex<V, F: PartialOrd>(vertices: &mut [V], costs: &mut [F]) {
    /// `true` when cost `a` must be placed strictly before cost `b`.
    fn goes_before<T: PartialOrd>(a: &T, b: &T) -> bool {
        match a.partial_cmp(b) {
            Some(ord) => ord == std::cmp::Ordering::Less,
            // Incomparable pair: `a` goes first only if it is itself comparable (not NaN),
            // i.e. `b` is the NaN and belongs further toward the worst end.
            None => a.partial_cmp(a).is_some(),
        }
    }

    for i in 1..vertices.len() {
        let mut j = i;
        while j > 0 && goes_before(&costs[j], &costs[j - 1]) {
            vertices.swap(j, j - 1);
            costs.swap(j, j - 1);
            j -= 1;
        }
    }
}
/// Evaluates the cost of every vertex, stores it in `state.costs`, then sorts the
/// simplex by ascending cost.
///
/// Vertices are evaluated in order; pairs are formed up to the shorter of
/// `state.vertices` and `state.costs`.
///
/// # Errors
/// Returns the first error produced by `problem.cost`. Costs computed before the
/// failure stay written, and the simplex is not sorted in that case.
fn init_costs_and_sort<P, V, F>(
    problem: &mut Problem<P>,
    state: &mut BasicSimplexState<V, F>,
) -> Result<(), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
{
    for (slot, vertex) in state.costs.iter_mut().zip(state.vertices.iter()) {
        *slot = problem.cost(vertex)?;
    }
    insertion_sort_simplex(&mut state.vertices, &mut state.costs);
    Ok(())
}
/// Makes sure `state.scratch` holds at least three vectors for the iteration
/// (centroid, reflected point, expansion/contraction point).
///
/// If fewer are present, the buffer is replaced by three copies of the first vertex,
/// which gives them the right shape.
///
/// # Panics
/// Panics if reallocation is needed and `state.vertices` is empty.
fn ensure_scratch<V, F>(state: &mut BasicSimplexState<V, F>)
where
    V: Clone,
{
    const SCRATCH_SLOTS: usize = 3;
    if state.scratch.len() >= SCRATCH_SLOTS {
        return;
    }
    let template = state.vertices[0].clone();
    state.scratch = vec![template; SCRATCH_SLOTS];
}
/// Performs one Nelder–Mead iteration on a simplex that is sorted by ascending cost.
///
/// With `m = vertices.len()` vertices (dimension `n = m - 1`), the worst vertex is
/// replaced by an accepted candidate, chosen from reflection, expansion, outside
/// contraction or inside contraction. If a contraction is rejected, the whole simplex is
/// shrunk toward the best vertex instead. Every candidate point is passed through
/// `project` before its cost is evaluated. The simplex is re-sorted before returning.
///
/// Candidates are built in `state.scratch[0..3]` (centroid, reflected point, and an
/// "alternative" slot for expansion/contraction). An accepted candidate is swapped into
/// `state.vertices`, so afterwards the scratch slots hold stale vertex data.
///
/// Returns the updated state together with `None`; this function never signals
/// termination itself.
///
/// # Panics
/// Panics if `state.vertices` has fewer than two entries or `state.scratch` fewer than three.
///
/// # Errors
/// Propagates the first error from `problem.cost`; the (moved-in) state is dropped then.
// The return type mirrors `Solver::next_iter`; an alias would only hide that correspondence.
#[allow(clippy::type_complexity)]
fn next_iter_inner<P, V, F, Proj>(
    problem: &mut Problem<P>,
    mut state: BasicSimplexState<V, F>,
    p: Params<F>,
    project: &Proj,
) -> Result<(BasicSimplexState<V, F>, Option<TerminationReason>), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: ScaleInPlace<F> + ScaledAdd<F>,
    Proj: Fn(&mut V),
{
    let m = state.vertices.len();
    // Dimension of the problem; also the index of the worst vertex (both equal m - 1).
    let n = m - 1;
    let worst = m - 1;
    // Costs of the best, second-worst and worst vertices.
    let f1 = state.costs[0];
    let fn_ = state.costs[n - 1];
    let fnp1 = state.costs[worst];
    // Split the scratch buffer into three disjoint &mut slots so they can be written
    // while `state.vertices` is read.
    let (xbar_slice, rest) = state.scratch.split_at_mut(1);
    let (xr_slice, alt_slice) = rest.split_at_mut(1);
    let x_bar = &mut xbar_slice[0];
    let x_r = &mut xr_slice[0];
    let x_alt = &mut alt_slice[0];
    // Centroid of every vertex except the worst.
    centroid_into(x_bar, &state.vertices[..n]);
    // Reflection: with t = -α, affine_into gives (1 + α)·x̄ − α·x_worst = x̄ + α(x̄ − x_worst).
    affine_into(x_r, x_bar, &state.vertices[worst], -p.alpha);
    project(x_r);
    let fr = problem.cost(x_r)?;
    if f1 <= fr && fr < fn_ {
        // Reflected point is neither a new best nor worse than the second-worst: accept it.
        std::mem::swap(&mut state.vertices[worst], x_r);
        state.costs[worst] = fr;
    } else if fr < f1 {
        // New best: try expanding further along the same direction,
        // x_e = x̄ + β(x_r − x̄). Keep x_e only if strictly better than x_r.
        affine_into(x_alt, x_bar, x_r, p.beta);
        project(x_alt);
        let fe = problem.cost(x_alt)?;
        if fe < fr {
            std::mem::swap(&mut state.vertices[worst], x_alt);
            state.costs[worst] = fe;
        } else {
            std::mem::swap(&mut state.vertices[worst], x_r);
            state.costs[worst] = fr;
        }
    } else if fr < fnp1 {
        // Outside contraction (fn <= fr < f_{n+1}): x_oc = x̄ + γ(x_r − x̄).
        // Accept if no worse than the reflected point, otherwise shrink.
        affine_into(x_alt, x_bar, x_r, p.gamma);
        project(x_alt);
        let foc = problem.cost(x_alt)?;
        if foc <= fr {
            std::mem::swap(&mut state.vertices[worst], x_alt);
            state.costs[worst] = foc;
        } else {
            shrink_inner(problem, &mut state, p.delta, project)?;
        }
    } else {
        // Inside contraction (fr >= f_{n+1}; for IEEE floats a NaN `fr` also lands here
        // because every comparison above is false): x_ic = x̄ + γ(x_worst − x̄).
        // Accept if strictly better than the worst vertex, otherwise shrink.
        affine_into(x_alt, x_bar, &state.vertices[worst], p.gamma);
        project(x_alt);
        let fic = problem.cost(x_alt)?;
        if fic < fnp1 {
            std::mem::swap(&mut state.vertices[worst], x_alt);
            state.costs[worst] = fic;
        } else {
            shrink_inner(problem, &mut state, p.delta, project)?;
        }
    }
    // Restore ascending order. The sort only moves a vertex past strictly worse costs,
    // so a replaced vertex lands after existing vertices of equal cost.
    insertion_sort_simplex(&mut state.vertices, &mut state.costs);
    Ok((state, None))
}
/// Shrinks every non-best vertex toward the best one and re-evaluates it.
///
/// Each vertex `v` at index ≥ 1 becomes `δ·v + (1 − δ)·best`, then passes through
/// `project`, and its cost is written to the matching slot of `state.costs`. The best
/// vertex (index 0) and its cost are left untouched. The simplex is not re-sorted here.
///
/// # Panics
/// Panics if `state.vertices` or `state.costs` is empty.
///
/// # Errors
/// Returns the first error from `problem.cost`; vertices processed before it stay updated.
fn shrink_inner<P, V, F, Proj>(
    problem: &mut Problem<P>,
    state: &mut BasicSimplexState<V, F>,
    delta: F,
    project: &Proj,
) -> Result<(), P::Error>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: ScaleInPlace<F> + ScaledAdd<F>,
    Proj: Fn(&mut V),
{
    let pull = F::one() - delta;
    let (head, tail) = state.vertices.split_at_mut(1);
    let anchor = &head[0];
    let tail_costs = &mut state.costs[1..];
    for (vertex, cost) in tail.iter_mut().zip(tail_costs.iter_mut()) {
        vertex.scale_in_place(delta);
        vertex.scaled_add(pull, anchor);
        project(vertex);
        *cost = problem.cost(vertex)?;
    }
    Ok(())
}
impl<P, V, F> Solver<P, BasicSimplexState<V, F>> for NelderMead<Unbounded, F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F>,
    V: Clone + ScaleInPlace<F> + ScaledAdd<F>,
{
    type Error = P::Error;

    /// Resolves the coefficients for the simplex dimension, allocates scratch buffers,
    /// evaluates every vertex and sorts the simplex by ascending cost.
    ///
    /// # Panics
    /// Panics if the simplex has fewer than two vertices.
    ///
    /// # Errors
    /// Propagates the first error returned by the cost function.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicSimplexState<V, F>,
    ) -> Result<BasicSimplexState<V, F>, Self::Error> {
        // `saturating_sub` maps an empty simplex to n = 0 so `resolve`'s dimension check
        // fires with a clear message. A plain `len() - 1` wraps to `usize::MAX` in release
        // builds and slips past that check.
        let n = state.vertices.len().saturating_sub(1);
        self.params = Some(Self::resolve(self.config, n));
        ensure_scratch(&mut state);
        init_costs_and_sort(problem, &mut state)?;
        Ok(state)
    }

    /// Runs one unprojected Nelder–Mead iteration.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    ///
    /// # Errors
    /// Propagates errors from the cost function.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        state: BasicSimplexState<V, F>,
    ) -> Result<(BasicSimplexState<V, F>, Option<TerminationReason>), Self::Error> {
        let p = self
            .params
            .expect("NelderMead::init must run before next_iter");
        // Unbounded mode: candidates are evaluated exactly as generated.
        next_iter_inner(problem, state, p, &|_: &mut V| {})
    }
}
impl<P, V, F> Solver<P, BasicSimplexState<V, F>> for NelderMead<Projected, F>
where
    F: Scalar,
    P: CostFunction<Param = V, Output = F> + BoxConstraints,
    V: Clone + ScaleInPlace<F> + ScaledAdd<F> + ClampInPlace,
{
    type Error = P::Error;

    /// Resolves the coefficients, clamps every initial vertex into the problem's box,
    /// allocates scratch buffers, evaluates all vertices and sorts the simplex by cost.
    ///
    /// # Panics
    /// Panics if the simplex has fewer than two vertices.
    ///
    /// # Errors
    /// Propagates the first error returned by the cost function.
    fn init(
        &mut self,
        problem: &mut Problem<P>,
        mut state: BasicSimplexState<V, F>,
    ) -> Result<BasicSimplexState<V, F>, Self::Error> {
        // `saturating_sub` maps an empty simplex to n = 0 so `resolve`'s dimension check
        // fires with a clear message. A plain `len() - 1` wraps to `usize::MAX` in release
        // builds and slips past that check.
        let n = state.vertices.len().saturating_sub(1);
        self.params = Some(Self::resolve(self.config, n));
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        // Start from a feasible simplex so every evaluated point respects the bounds.
        for v in state.vertices.iter_mut() {
            v.clamp_in_place(&lo, &hi);
        }
        ensure_scratch(&mut state);
        init_costs_and_sort(problem, &mut state)?;
        Ok(state)
    }

    /// Runs one Nelder–Mead iteration, clamping every candidate into the box bounds
    /// before it is evaluated.
    ///
    /// # Panics
    /// Panics if `init` has not been called first.
    ///
    /// # Errors
    /// Propagates errors from the cost function.
    fn next_iter(
        &mut self,
        problem: &mut Problem<P>,
        state: BasicSimplexState<V, F>,
    ) -> Result<(BasicSimplexState<V, F>, Option<TerminationReason>), Self::Error> {
        let p = self
            .params
            .expect("NelderMead::init must run before next_iter");
        // The bounds are copied out because the projection closure needs them while
        // `problem` is handed mutably to `next_iter_inner` (presumably `lower()`/`upper()`
        // borrow from the problem — confirm against `BoxConstraints`).
        let lo = problem.inner().lower().clone();
        let hi = problem.inner().upper().clone();
        next_iter_inner(problem, state, p, &|v: &mut V| v.clamp_in_place(&lo, &hi))
    }
}