use crate::{
algorithms::gradient_free::GradientFreeStatus,
core::{Bounds, Callbacks, MinimizationSummary, Point},
error::{GaneshError, GaneshResult},
traits::algorithm::{resolve_bounds_and_transform, BoundsHandlingMode},
traits::{
Algorithm, CheckpointableAlgorithm, CostFunction, Status, SupportsBounds,
SupportsParameterNames, SupportsTransform, Terminator, Transform,
},
DMatrix, DVector, Float,
};
use serde::{Deserialize, Serialize};
use std::{fmt::Debug, ops::ControlFlow};
/// Strategy used to build the initial simplex for [`NelderMead`].
///
/// Each variant yields `n + 1` vertices for an `n`-dimensional starting point.
#[derive(Debug, Clone)]
pub enum SimplexConstructionMethod {
    /// Start at `x0` (external coordinates); vertex `i + 1` is vertex 0 with coordinate `i`
    /// multiplied by `orthogonal_multiplier`, or set to `orthogonal_zero_step` when that
    /// coordinate is exactly `0.0`.
    ScaledOrthogonal {
        /// Starting point, in external (user) coordinates.
        x0: DVector<Float>,
        /// Factor applied to a nonzero coordinate to build its neighbouring vertex.
        orthogonal_multiplier: Float,
        /// Replacement value for a coordinate that is exactly `0.0`.
        orthogonal_zero_step: Float,
    },
    /// Start at `x0` (external coordinates); vertex `i + 1` is vertex 0 with `simplex_size`
    /// added to coordinate `i`.
    Orthogonal {
        /// Starting point, in external (user) coordinates.
        x0: DVector<Float>,
        /// Step added along each coordinate axis.
        simplex_size: Float,
    },
    /// Caller-supplied vertices, used as given (subject to bound handling).
    Custom {
        /// The `n + 1` vertices of the initial simplex.
        simplex: Vec<DVector<Float>>,
    },
}
impl SimplexConstructionMethod {
    /// Returns the point the simplex is built around, in external coordinates.
    ///
    /// For [`Self::Custom`] this is the first supplied vertex, or an empty vector if the list is
    /// empty (which [`Self::custom`] never allows).
    fn starting_point(&self) -> DVector<Float> {
        match self {
            Self::ScaledOrthogonal { x0, .. } | Self::Orthogonal { x0, .. } => x0.clone(),
            Self::Custom { simplex } => simplex.first().cloned().unwrap_or_default(),
        }
    }

    /// Scaled-orthogonal construction around `x0` with multiplier `1.05` and zero step `0.00025`.
    pub fn scaled_orthogonal<I>(x0: I) -> Self
    where
        I: AsRef<[Float]>,
    {
        Self::ScaledOrthogonal {
            x0: DVector::from_row_slice(x0.as_ref()),
            orthogonal_multiplier: 1.05,
            orthogonal_zero_step: 0.00025,
        }
    }

    /// Scaled-orthogonal construction around `x0` with caller-chosen parameters.
    ///
    /// Vertex `i + 1` is `x0` with coordinate `i` multiplied by `orthogonal_multiplier`, or set to
    /// `orthogonal_zero_step` when that coordinate is exactly `0.0`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] if either parameter is not greater than 0. `NaN` is
    /// rejected explicitly, because `NaN <= 0.0` is false and would otherwise slip through and
    /// poison every vertex.
    pub fn custom_scaled_orthogonal<I>(
        x0: I,
        orthogonal_multiplier: Float,
        orthogonal_zero_step: Float,
    ) -> GaneshResult<Self>
    where
        I: AsRef<[Float]>,
    {
        if orthogonal_multiplier.is_nan() || orthogonal_multiplier <= 0.0 {
            return Err(GaneshError::ConfigError(
                "orthogonal_multiplier must be greater than 0".to_string(),
            ));
        }
        if orthogonal_zero_step.is_nan() || orthogonal_zero_step <= 0.0 {
            return Err(GaneshError::ConfigError(
                "orthogonal_zero_step must be greater than 0".to_string(),
            ));
        }
        Ok(Self::ScaledOrthogonal {
            x0: DVector::from_row_slice(x0.as_ref()),
            orthogonal_multiplier,
            orthogonal_zero_step,
        })
    }

    /// Orthogonal construction around `x0` with a unit step along each axis.
    pub fn orthogonal<I>(x0: I) -> Self
    where
        I: AsRef<[Float]>,
    {
        Self::Orthogonal {
            x0: DVector::from_row_slice(x0.as_ref()),
            simplex_size: 1.0,
        }
    }

    /// Orthogonal construction around `x0`: vertex `i + 1` is `x0` with `simplex_size` added to
    /// coordinate `i`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] if `simplex_size` is not greater than 0 (`NaN`
    /// included).
    pub fn custom_orthogonal<I>(x0: I, simplex_size: Float) -> GaneshResult<Self>
    where
        I: AsRef<[Float]>,
    {
        if simplex_size.is_nan() || simplex_size <= 0.0 {
            return Err(GaneshError::ConfigError(
                "simplex_size must be greater than 0".to_string(),
            ));
        }
        Ok(Self::Orthogonal {
            x0: DVector::from_row_slice(x0.as_ref()),
            simplex_size,
        })
    }

    /// Uses the given vertices as the initial simplex.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] if `simplex` is empty, if its points have dimension
    /// below 2, if the points do not all share one dimension, or if there are not exactly
    /// `n + 1` points for dimension `n`.
    pub fn custom<I>(simplex: I) -> GaneshResult<Self>
    where
        I: AsRef<[DVector<Float>]>,
    {
        let simplex = simplex.as_ref();
        let Some(first) = simplex.first() else {
            return Err(GaneshError::ConfigError(
                "Custom simplex must not be empty".to_string(),
            ));
        };
        if first.len() < 2 {
            return Err(GaneshError::ConfigError(
                "Nelder-Mead is only a suitable method for problems of dimension >= 2".to_string(),
            ));
        }
        if simplex.iter().any(|point| point.len() != first.len()) {
            return Err(GaneshError::ConfigError(
                "Custom simplex points must all have the same dimension".to_string(),
            ));
        }
        if simplex.len() != first.len() + 1 {
            return Err(GaneshError::ConfigError(
                "Custom simplex must contain exactly n + 1 points for dimension n".to_string(),
            ));
        }
        Ok(Self::Custom {
            simplex: simplex.to_vec(),
        })
    }
}
/// Initialization input for [`NelderMead`]: describes how the starting simplex is built.
#[derive(Clone)]
pub struct NelderMeadInit {
    /// Strategy used by `initialize` to generate the initial vertices.
    construction_method: SimplexConstructionMethod,
}
impl NelderMeadInit {
    /// Starts from `x0` using the default scaled-orthogonal simplex
    /// ([`SimplexConstructionMethod::scaled_orthogonal`]).
    pub fn new<I>(x0: I) -> Self
    where
        I: AsRef<[Float]>,
    {
        Self::new_with_method(SimplexConstructionMethod::scaled_orthogonal(x0))
    }

    /// Uses an explicitly chosen simplex construction strategy.
    pub const fn new_with_method(construction_method: SimplexConstructionMethod) -> Self {
        Self { construction_method }
    }

    /// Starts from a caller-supplied simplex.
    ///
    /// # Errors
    ///
    /// Forwards any validation error from [`SimplexConstructionMethod::custom`].
    pub fn custom<I>(simplex: I) -> GaneshResult<Self>
    where
        I: AsRef<[DVector<Float>]>,
    {
        SimplexConstructionMethod::custom(simplex).map(Self::new_with_method)
    }

    /// The point the simplex is built around, in external coordinates.
    fn starting_point(&self) -> DVector<Float> {
        let Self { construction_method } = self;
        construction_method.starting_point()
    }
}
impl SimplexConstructionMethod {
    /// Pulls each coordinate of `x` back inside `bounds`, then clips.
    ///
    /// A coordinate above its upper bound `u` is mirrored to `2u - v`, and one below its lower
    /// bound `l` to `2l - v`. Mirroring (rather than clipping alone) keeps a vertex that was
    /// stepped past a bound from collapsing onto its neighbour, which would give a zero-volume
    /// simplex. The final clip catches any value a mirror pushed past the opposite bound. This is
    /// a no-op when `bounds` is `None`.
    fn reflect_into_bounds(x: &mut DVector<Float>, bounds: Option<&Bounds>) {
        let Some(bounds) = bounds else {
            return;
        };
        x.iter_mut().zip(bounds.iter()).for_each(|(v, b)| {
            if *v > b.0.upper() {
                *v = Float::mul_add(2.0, b.0.upper(), -(*v));
            } else if *v < b.0.lower() {
                *v = Float::mul_add(2.0, b.0.lower(), -(*v));
            }
        });
        let clipped = bounds.clip_values(&*x);
        *x = clipped;
    }

    /// Builds an axis-aligned simplex around `x0`.
    ///
    /// Vertex 0 is `x0` mapped to internal coordinates. Vertex `i + 1` is vertex 0 with
    /// coordinate `i` modified by `perturb` and then pulled back inside `bounds`. Every vertex is
    /// evaluated.
    ///
    /// # Panics
    ///
    /// Panics if the point has fewer than 2 coordinates.
    fn generate_axis_aligned<U, E>(
        func: &dyn CostFunction<U, E>,
        transform: &Option<Box<dyn Transform>>,
        bounds: Option<&Bounds>,
        args: &U,
        x0: &DVector<Float>,
        perturb: impl Fn(&mut Float),
    ) -> Result<Simplex, E> {
        let mut point_0 = Point::from(transform.to_internal(x0).into_owned());
        point_0.evaluate_transformed(func, transform, args)?;
        let dim = point_0.x.len();
        assert!(
            dim >= 2,
            "Nelder-Mead is only a suitable method for problems of dimension >= 2"
        );
        let mut points = Vec::with_capacity(dim + 1);
        points.push(point_0.clone());
        for i in 0..dim {
            let mut point_i = point_0.clone();
            perturb(&mut point_i.x[i]);
            Self::reflect_into_bounds(&mut point_i.x, bounds);
            // The clone carries vertex 0's cached value; force a fresh evaluation.
            point_i.fx = None;
            point_i.evaluate_transformed(func, transform, args)?;
            points.push(point_i);
        }
        Ok(Simplex::new(&points))
    }

    /// Builds and evaluates the initial simplex in internal coordinates.
    ///
    /// # Errors
    ///
    /// Propagates the first error returned by `func` while evaluating a vertex.
    ///
    /// # Panics
    ///
    /// Panics if the problem dimension is below 2, or if a `Custom` simplex does not have
    /// `n + 1` points.
    fn generate<U, E>(
        &self,
        func: &dyn CostFunction<U, E>,
        transform: &Option<Box<dyn Transform>>,
        bounds: Option<&Bounds>,
        args: &U,
    ) -> Result<Simplex, E> {
        match self {
            Self::ScaledOrthogonal {
                x0,
                orthogonal_multiplier,
                orthogonal_zero_step,
            } => Self::generate_axis_aligned(func, transform, bounds, args, x0, |v| {
                if *v == 0.0 {
                    *v = *orthogonal_zero_step;
                } else {
                    *v *= *orthogonal_multiplier;
                }
            }),
            Self::Orthogonal { x0, simplex_size } => {
                Self::generate_axis_aligned(func, transform, bounds, args, x0, |v| {
                    *v += *simplex_size;
                })
            }
            Self::Custom { simplex } => {
                assert!(!simplex.is_empty());
                assert!(simplex.len() == simplex[0].len() + 1);
                assert!(simplex.len() > 2);
                let points = simplex
                    .iter()
                    .map(|x| {
                        let mut point_i = Point::from(transform.to_internal(x).into_owned());
                        Self::reflect_into_bounds(&mut point_i.x, bounds);
                        point_i.evaluate_transformed(func, transform, args)?;
                        Ok(point_i)
                    })
                    .collect::<Result<Vec<Point<DVector<Float>>>, E>>()?;
                Ok(Simplex::new(&points))
            }
        }
    }
}
/// The working simplex of a [`NelderMead`] run.
///
/// Besides the vertices, it keeps a running centroid and a running volume figure that the
/// termination criteria read. Both are updated incrementally by the methods on `Simplex` and
/// by the algorithm's step.
#[derive(Default, Clone, Serialize, Deserialize)]
pub struct Simplex {
    /// Vertices, ordered best-first after construction and after every `sort`.
    points: Vec<Point<DVector<Float>>>,
    /// Number of vertices at construction (`points.len()`, i.e. `n + 1` for `n` parameters).
    dimension: usize,
    /// Flag consulted by `sort`; when `true`, `sort` does nothing.
    sorted: bool,
    /// Mean position of all vertices, maintained incrementally on insertion.
    total_centroid: DVector<Float>,
    /// Starts as `sqrt(det(G))` of the edge Gram matrix, then is multiplied by each step's
    /// volume factor (see `scale_volume`).
    volume: Float,
    /// Best vertex of the initial simplex (used by the Rowan criterion).
    initial_best: Point<DVector<Float>>,
    /// Worst vertex of the initial simplex (used by the Rowan criterion).
    initial_worst: Point<DVector<Float>>,
    /// `volume` at construction (used by the Singer criterion).
    initial_volume: Float,
}
impl Debug for Simplex {
    /// Pretty-prints just the vertex list; bookkeeping fields are omitted.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let points = &self.points;
        write!(f, "{points:#?}")
    }
}
impl Simplex {
    /// Builds a simplex from `points`, ordering the vertices best-first.
    ///
    /// Also records the initial best/worst vertices, the centroid, and the initial volume. The
    /// volume is `sqrt(det(G))`, where `G` is the Gram matrix of the edge vectors from the best
    /// vertex.
    ///
    /// # Panics
    ///
    /// Panics if `points` is empty.
    fn new(points: &[Point<DVector<Float>>]) -> Self {
        let mut sorted_points = points.to_vec();
        sorted_points.sort_by(|a, b| a.total_cmp(b));
        let initial_best = sorted_points[0].clone();
        let initial_worst = sorted_points[sorted_points.len() - 1].clone();
        let n_params = points.len() - 1;
        let diffs: Vec<DVector<Float>> = sorted_points
            .iter()
            .skip(1)
            .map(|p| &p.x - &initial_best.x)
            .collect();
        let gram_mat = DMatrix::from_fn(n_params, n_params, |i, j| diffs[i].dot(&diffs[j]));
        // A Gram matrix is positive semi-definite, but for a (nearly) degenerate simplex rounding
        // can make its determinant slightly negative. Without the clamp, sqrt would then produce
        // NaN, and a NaN volume makes every Singer comparison false.
        let volume = Float::sqrt(Float::max(gram_mat.determinant(), 0.0));
        let total_centroid =
            sorted_points.iter().map(|p| &p.x).sum::<DVector<Float>>() / points.len() as Float;
        Self {
            points: sorted_points,
            dimension: points.len(),
            // The points were sorted just above.
            sorted: true,
            total_centroid,
            volume,
            initial_best,
            initial_worst,
            initial_volume: volume,
        }
    }

    /// Number of vertices.
    fn size(&self) -> usize {
        self.points.len()
    }

    /// Centroid of every vertex except the last (worst) one, derived from the running total
    /// centroid.
    fn corrected_centroid(&self) -> DVector<Float> {
        let n = self.points.len();
        let total = &self.total_centroid * (n as Float);
        let sum = total - &self.points[n - 1].x;
        sum / ((n - 1) as Float)
    }

    /// Best vertex as `(x, f(x))`, with `x` mapped back to external coordinates.
    fn best_position(&self, transform: &Option<Box<dyn Transform>>) -> (DVector<Float>, Float) {
        let best = self.best();
        (transform.to_owned_external(&best.x), best.fx_checked())
    }

    /// First vertex (the best one while sorted).
    fn best(&self) -> &Point<DVector<Float>> {
        &self.points[0]
    }

    /// Last vertex (the worst one while sorted).
    fn worst(&self) -> &Point<DVector<Float>> {
        &self.points[self.points.len() - 1]
    }

    /// Second-to-last vertex (the second worst while sorted).
    fn second_worst(&self) -> &Point<DVector<Float>> {
        &self.points[self.points.len() - 2]
    }

    /// Replaces the last vertex with `element`, updates the centroid, then re-sorts.
    fn insert_and_sort(&mut self, index: usize, element: Point<DVector<Float>>) {
        let removed = self.points.remove(self.points.len() - 1);
        // Length before removal, i.e. the vertex count.
        let n = self.points.len() as Float + 1.0;
        self.total_centroid += (&element.x - &removed.x) / n;
        self.points.insert(index, element);
        self.sorted = false;
        self.sort();
    }

    /// Replaces the last vertex with `element` at `index`, which the caller guarantees keeps the
    /// vertices sorted, and updates the centroid.
    fn insert_sorted(&mut self, index: usize, element: Point<DVector<Float>>) {
        let removed = self.points.remove(self.points.len() - 1);
        self.points.insert(index, element);
        self.sorted = true;
        let n = self.points.len() as Float;
        self.total_centroid += (&self.points[index].x - &removed.x) / n;
    }

    /// Sorts vertices best-first unless already flagged as sorted.
    fn sort(&mut self) {
        if !self.sorted {
            self.sorted = true;
            self.points.sort_by(|a, b| a.total_cmp(b));
        }
    }

    /// Recomputes the centroid from scratch (used after every vertex has moved).
    fn compute_total_centroid(&mut self) {
        let n = self.points.len() as Float;
        self.total_centroid = self.points.iter().map(|p| &p.x).sum::<DVector<Float>>() / n;
    }

    /// Multiplies the tracked volume by `factor`.
    fn scale_volume(&mut self, factor: Float) {
        self.volume *= factor;
    }
}
/// How [`NelderMead`] chooses between the reflected point `x_r` and the expanded point `x_e`
/// during an expansion step.
///
/// An expansion step is only attempted when `x_r` beats the current best vertex.
#[derive(Default, Debug, Clone)]
pub enum SimplexExpansionMethod {
    /// Keep whichever of `x_r` and `x_e` compares better (the default).
    #[default]
    GreedyMinimization,
    /// Prefer the expanded point `x_e`; see `NelderMead::step` for the exact acceptance rule.
    GreedyExpansion,
}
/// Function-value convergence criteria for [`NelderMead`].
///
/// Below, `f_h` and `f_l` are the values at the worst and best vertices. The [`Default`] is
/// `StdDev` with `eps_abs = Float::EPSILON.powf(0.25)`.
#[derive(Debug, Clone)]
pub enum NelderMeadFTerminator {
    /// Relative spread: stop when `2 (f_h - f_l) / (|f_h| + |f_l|) <= eps_rel`.
    Amoeba {
        eps_rel: Float,
    },
    /// Absolute spread: stop when `f_h - f_l <= eps_abs`.
    Absolute {
        eps_abs: Float,
    },
    /// Stop when the population standard deviation of the vertex values is `<= eps_abs`.
    StdDev {
        eps_abs: Float,
    },
}
impl Default for NelderMeadFTerminator {
    /// Standard-deviation test with tolerance `Float::EPSILON^(1/4)`.
    fn default() -> Self {
        let eps_abs = Float::powf(Float::EPSILON, 0.25);
        Self::StdDev { eps_abs }
    }
}
impl<P, U, E> Terminator<NelderMead, P, GradientFreeStatus, U, E, NelderMeadConfig>
    for NelderMeadFTerminator
where
    P: CostFunction<U, E>,
{
    /// Evaluates the selected function-value criterion on the current simplex.
    ///
    /// When the criterion is met, marks `status` successful with a `term_f = ...` message and
    /// returns `Break`.
    fn check_for_termination(
        &mut self,
        _current_step: usize,
        algorithm: &mut NelderMead,
        _problem: &P,
        status: &mut GradientFreeStatus,
        _args: &U,
        _config: &NelderMeadConfig,
    ) -> ControlFlow<()> {
        let simplex = &algorithm.simplex;
        match self {
            Self::Amoeba { eps_rel: eps_f_rel } => {
                let fh = simplex.worst().fx_checked();
                let fl = simplex.best().fx_checked();
                let scale = Float::abs(fh) + Float::abs(fl);
                // If both values are exactly zero, the ratio is 0/0 = NaN. NaN compares false, so
                // an already-converged run would never stop; treat that case as converged. Every
                // other input uses the unchanged ratio test.
                let converged = scale == 0.0 || 2.0 * (fh - fl) / scale <= *eps_f_rel;
                if converged {
                    status.set_message().succeed_with_message("term_f = AMOEBA");
                    return ControlFlow::Break(());
                }
            }
            Self::Absolute { eps_abs: eps_f_abs } => {
                let fh = simplex.worst().fx_checked();
                let fl = simplex.best().fx_checked();
                if fh - fl <= *eps_f_abs {
                    status
                        .set_message()
                        .succeed_with_message("term_f = ABSOLUTE");
                    return ControlFlow::Break(());
                }
            }
            Self::StdDev { eps_abs: eps_f_abs } => {
                // `dimension` is the vertex count, so this is the population mean/std-dev.
                let dim = simplex.dimension as Float;
                let mean = simplex
                    .points
                    .iter()
                    .map(|point| point.fx_checked())
                    .sum::<Float>()
                    / dim;
                let std_dev = Float::sqrt(
                    simplex
                        .points
                        .iter()
                        .map(|point| Float::powi(point.fx_checked() - mean, 2))
                        .sum::<Float>()
                        / dim,
                );
                if std_dev <= *eps_f_abs {
                    status.set_message().succeed_with_message("term_f = STDDEV");
                    return ControlFlow::Break(());
                }
            }
        }
        ControlFlow::Continue(())
    }
}
/// Simplex-size (parameter-space) convergence criteria for [`NelderMead`].
///
/// The [`Default`] is `Singer` with `eps_rel = Float::EPSILON.powf(0.25)`.
#[derive(Debug, Clone)]
pub enum NelderMeadXTerminator {
    /// Absolute test: infinity-norm distances between the best vertex and the other vertices,
    /// compared against `eps_abs`.
    Diameter {
        eps_abs: Float,
    },
    /// Relative test: 1-norm distances from the best vertex to the other vertices, divided by
    /// `max(1, ||x_best||_1)`, compared against `eps_rel`.
    Higham {
        eps_rel: Float,
    },
    /// Stop when `||x_worst - x_best||_2 <= eps_rel * ||x_worst0 - x_best0||_2`, where `x_worst0`
    /// and `x_best0` are the worst and best vertices of the initial simplex.
    Rowan {
        eps_rel: Float,
    },
    /// Relative test on a linearized simplex volume (a root of the tracked volume), compared
    /// with `eps_rel` times the same quantity for the initial simplex.
    Singer {
        eps_rel: Float,
    },
}
impl Default for NelderMeadXTerminator {
    /// Singer's volume test with tolerance `Float::EPSILON^(1/4)`.
    fn default() -> Self {
        let eps_rel = Float::powf(Float::EPSILON, 0.25);
        Self::Singer { eps_rel }
    }
}
impl<P, U, E> Terminator<NelderMead, P, GradientFreeStatus, U, E, NelderMeadConfig>
    for NelderMeadXTerminator
where
    P: CostFunction<U, E>,
{
    /// Evaluates the selected simplex-size criterion.
    ///
    /// When the criterion is met, marks `status` successful with a `term_x = ...` message and
    /// returns `Break`.
    fn check_for_termination(
        &mut self,
        _current_step: usize,
        algorithm: &mut NelderMead,
        _problem: &P,
        status: &mut GradientFreeStatus,
        _args: &U,
        _config: &NelderMeadConfig,
    ) -> ControlFlow<()> {
        let simplex = &algorithm.simplex;
        match self {
            Self::Diameter { eps_abs: eps_x_abs } => {
                let l = simplex.best();
                // `points` is best-first, so `skip(1)` visits every vertex except the best one.
                // Reversing first would skip the *worst* vertex instead and understate the
                // diameter.
                let max_inf_norm = simplex
                    .points
                    .iter()
                    .skip(1)
                    .map(|point| {
                        (&point.x - &l.x)
                            .iter()
                            .fold(0.0, |acc: Float, d| Float::max(acc, Float::abs(*d)))
                    })
                    .max_by(|&a, &b| a.total_cmp(&b))
                    .unwrap_or(0.0);
                if max_inf_norm <= *eps_x_abs {
                    status
                        .set_message()
                        .succeed_with_message("term_x = DIAMETER");
                    return ControlFlow::Break(());
                }
            }
            Self::Higham { eps_rel: eps_x_rel } => {
                let l = simplex.best();
                let denom = Float::max(l.x.lp_norm(1), 1.0);
                // Every vertex except the best one (see the Diameter branch).
                let numer = simplex
                    .points
                    .iter()
                    .skip(1)
                    .map(|point| (&point.x - &l.x).lp_norm(1))
                    .max_by(|&a, &b| a.total_cmp(&b))
                    .unwrap_or(0.0);
                if numer / denom <= *eps_x_rel {
                    status.set_message().succeed_with_message("term_x = HIGHAM");
                    return ControlFlow::Break(());
                }
            }
            Self::Rowan { eps_rel: eps_x_rel } => {
                let init_diff = (&simplex.initial_worst.x - &simplex.initial_best.x).lp_norm(2);
                let current_diff = (&simplex.worst().x - &simplex.best().x).lp_norm(2);
                if current_diff <= *eps_x_rel * init_diff {
                    status.set_message().succeed_with_message("term_x = ROWAN");
                    return ControlFlow::Break(());
                }
            }
            Self::Singer { eps_rel: eps_x_rel } => {
                // `dimension` counts vertices (n + 1), while the tracked volume is n-dimensional.
                // The shrink step scales it by `delta^(dimension - 1)`, so its linear size is
                // V^(1/n) with n = dimension - 1.
                let n = simplex.dimension.saturating_sub(1) as Float;
                let lv_init = Float::powf(simplex.initial_volume, 1.0 / n);
                let lv_current = Float::powf(simplex.volume, 1.0 / n);
                if lv_current <= *eps_x_rel * lv_init {
                    status.set_message().succeed_with_message("term_x = SINGER");
                    return ControlFlow::Break(());
                }
            }
        }
        ControlFlow::Continue(())
    }
}
/// Configuration for [`NelderMead`].
///
/// Holds optional bounds, transform and parameter names; the four Nelder–Mead coefficients
/// (reflection `alpha`, expansion `beta`, contraction `gamma`, shrink `delta`); and the expansion
/// acceptance rule. Start from [`NelderMeadConfig::default`] and use the validating `with_*`
/// builders.
#[derive(Clone)]
pub struct NelderMeadConfig {
    /// Bounds as supplied by the user; `initialize` maps them into internal coordinates.
    bounds: Option<Bounds>,
    /// How bounds are enforced; passed to `resolve_bounds_and_transform`.
    bounds_handling: BoundsHandlingMode,
    /// Optional parameter names, copied into the summary.
    parameter_names: Option<Vec<String>>,
    /// Optional user transform, combined with the bounds by `resolve_bounds_and_transform`.
    transform: Option<Box<dyn Transform>>,
    /// Reflection coefficient.
    alpha: Float,
    /// Expansion coefficient.
    beta: Float,
    /// Contraction coefficient.
    gamma: Float,
    /// Shrink coefficient.
    delta: Float,
    /// Rule for choosing between the reflected and expanded points.
    expansion_method: SimplexExpansionMethod,
}
impl NelderMeadConfig {
    /// Creates a configuration with the default coefficients (same as [`Default::default`]).
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the reflection coefficient `alpha`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `value > 0`. `NaN` is rejected explicitly,
    /// since every comparison with it is false.
    pub fn with_alpha(mut self, value: Float) -> GaneshResult<Self> {
        if value.is_nan() || value <= 0.0 {
            return Err(GaneshError::ConfigError(
                "Reflection coefficient alpha must be greater than 0".to_string(),
            ));
        }
        self.alpha = value;
        Ok(self)
    }

    /// Sets the expansion coefficient `beta`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `value > 1` and `value` is greater than the
    /// current `alpha` (`NaN` rejected).
    pub fn with_beta(mut self, value: Float) -> GaneshResult<Self> {
        if value.is_nan() || value <= 1.0 {
            return Err(GaneshError::ConfigError(
                "Expansion coefficient beta must be greater than 1".to_string(),
            ));
        }
        if value <= self.alpha {
            return Err(GaneshError::ConfigError(format!(
                "Expansion coefficient beta must be greater than reflection coefficient alpha ({})",
                self.alpha
            )));
        }
        self.beta = value;
        Ok(self)
    }

    /// Sets `alpha` and `beta` together, validating them against each other rather than against
    /// the current values.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `alpha > 0`, `beta > 1` and `beta > alpha`
    /// (`NaN` rejected).
    pub fn with_alpha_beta(mut self, alpha: Float, beta: Float) -> GaneshResult<Self> {
        if alpha.is_nan() || alpha <= 0.0 {
            return Err(GaneshError::ConfigError(
                "Reflection coefficient alpha must be greater than 0".to_string(),
            ));
        }
        if beta.is_nan() || beta <= 1.0 {
            return Err(GaneshError::ConfigError(
                "Expansion coefficient beta must be greater than 1".to_string(),
            ));
        }
        if beta <= alpha {
            return Err(GaneshError::ConfigError(
                "Expansion coefficient beta must be greater than reflection coefficient alpha"
                    .to_string(),
            ));
        }
        self.alpha = alpha;
        self.beta = beta;
        Ok(self)
    }

    /// Sets the contraction coefficient `gamma`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `0 < value < 1` (`NaN` rejected).
    pub fn with_gamma(mut self, value: Float) -> GaneshResult<Self> {
        if value.is_nan() || value >= 1.0 || value <= 0.0 {
            return Err(GaneshError::ConfigError(
                "Contraction coefficient gamma must be in (0, 1)".to_string(),
            ));
        }
        self.gamma = value;
        Ok(self)
    }

    /// Sets the shrink coefficient `delta`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] unless `0 < value < 1` (`NaN` rejected).
    pub fn with_delta(mut self, value: Float) -> GaneshResult<Self> {
        if value.is_nan() || value >= 1.0 || value <= 0.0 {
            return Err(GaneshError::ConfigError(
                "Shrink coefficient delta must be in (0, 1)".to_string(),
            ));
        }
        self.delta = value;
        Ok(self)
    }

    /// Sets all four coefficients to the dimension-dependent values of Gao & Han for an
    /// `n`-dimensional problem.
    ///
    /// The values are `alpha = 1`, `beta = 1 + 2/n`, `gamma = 0.75 - 1/(2n)` and
    /// `delta = 1 - 1/n`.
    ///
    /// # Errors
    ///
    /// Returns [`GaneshError::ConfigError`] if `n < 2`. For `n = 1` the formulas give
    /// `delta = 0`, which [`Self::with_delta`] would reject (a zero shrink collapses the simplex
    /// onto its best vertex). Nelder–Mead also requires dimension >= 2.
    pub fn with_adaptive(mut self, n: usize) -> GaneshResult<Self> {
        if n < 2 {
            return Err(GaneshError::ConfigError(
                "Adaptive hyperparameters requires input dimension >= 2".to_string(),
            ));
        }
        let n = n as Float;
        self.alpha = 1.0;
        self.beta = 1.0 + (2.0 / n);
        self.gamma = 0.75 - 1.0 / (2.0 * n);
        self.delta = 1.0 - 1.0 / n;
        Ok(self)
    }

    /// Selects how the expansion step chooses between the reflected and expanded points.
    pub const fn with_expansion_method(mut self, method: SimplexExpansionMethod) -> Self {
        self.expansion_method = method;
        self
    }

    /// Selects how bounds are enforced.
    pub const fn with_bounds_handling(mut self, bounds_handling: BoundsHandlingMode) -> Self {
        self.bounds_handling = bounds_handling;
        self
    }
}
impl Default for NelderMeadConfig {
    /// Classic coefficients (`alpha = 1`, `beta = 2`, `gamma = 0.5`, `delta = 0.5`), with no
    /// bounds, transform or parameter names, and the default expansion and bounds-handling modes.
    fn default() -> Self {
        let (alpha, beta, gamma, delta) = (1.0, 2.0, 0.5, 0.5);
        Self {
            alpha,
            beta,
            gamma,
            delta,
            expansion_method: SimplexExpansionMethod::default(),
            bounds_handling: BoundsHandlingMode::default(),
            bounds: None,
            transform: None,
            parameter_names: None,
        }
    }
}
impl SupportsBounds for NelderMeadConfig {
    /// Mutable access to the user-supplied bounds.
    fn get_bounds_mut(&mut self) -> &mut Option<Bounds> {
        let Self { bounds, .. } = self;
        bounds
    }
}
impl SupportsTransform for NelderMeadConfig {
    /// Mutable access to the user-supplied transform.
    fn get_transform_mut(&mut self) -> &mut Option<Box<dyn Transform>> {
        let Self { transform, .. } = self;
        transform
    }
}
impl SupportsParameterNames for NelderMeadConfig {
    /// Mutable access to the optional parameter names.
    fn get_parameter_names_mut(&mut self) -> &mut Option<Vec<String>> {
        let Self {
            parameter_names, ..
        } = self;
        parameter_names
    }
}
/// The Nelder–Mead downhill-simplex minimizer.
#[derive(Clone, Default)]
pub struct NelderMead {
    /// Current simplex, in internal (transformed) coordinates.
    simplex: Simplex,
    /// Config bounds mapped into internal coordinates; trial points are clipped to these.
    internal_bounds: Option<Bounds>,
    /// Transform produced by `resolve_bounds_and_transform` from the config.
    resolved_transform: Option<Box<dyn Transform>>,
    /// Starting point in external coordinates, reported as `x0` in the summary.
    initial_x0: DVector<Float>,
}
/// Serializable snapshot of a [`NelderMead`] run, used to resume it later.
///
/// Bounds and transforms are not stored; `restore` re-derives them from the config.
#[derive(Clone, Serialize, Deserialize)]
pub struct NelderMeadCheckpoint {
    /// The simplex at checkpoint time.
    pub simplex: Simplex,
    /// Original starting point (external coordinates).
    pub initial_x0: DVector<Float>,
    /// Status at checkpoint time (evaluation counts, position, message).
    pub status: GradientFreeStatus,
    /// Step index handed back by `restore` as the step to resume from.
    pub next_step: usize,
}
impl<P, U, E> Algorithm<P, GradientFreeStatus, U, E> for NelderMead
where
    P: CostFunction<U, E>,
{
    type Summary = MinimizationSummary;
    type Config = NelderMeadConfig;
    type Init = NelderMeadInit;

    /// Prepares a fresh run.
    ///
    /// Resolves bounds and transform from `config`, maps the bounds into internal coordinates,
    /// and builds and evaluates the initial simplex. It then counts those evaluations in `status`
    /// and seeds `status` with the best initial vertex (external coordinates).
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cost function while evaluating the initial simplex.
    fn initialize(
        &mut self,
        problem: &P,
        status: &mut GradientFreeStatus,
        args: &U,
        init: &Self::Init,
        config: &Self::Config,
    ) -> Result<(), E> {
        let (bounds, transform): (Option<Bounds>, Option<Box<dyn Transform>>) =
            resolve_bounds_and_transform(&config.bounds, &config.transform, config.bounds_handling);
        let internal_bounds = bounds.map(|b| b.apply(&transform));
        self.internal_bounds = internal_bounds;
        self.resolved_transform = transform;
        self.simplex = init.construction_method.generate(
            problem,
            &self.resolved_transform,
            self.internal_bounds.as_ref(),
            args,
        )?;
        status.n_f_evals += self.simplex.size();
        self.initial_x0 = init.starting_point();
        status.initialize(self.simplex.best_position(&self.resolved_transform));
        Ok(())
    }

    /// Performs one Nelder–Mead iteration.
    ///
    /// Reflects the worst vertex through the centroid of the others. Depending on how the
    /// reflected point ranks, it then accepts it, expands, contracts outside or inside, or
    /// finally shrinks the whole simplex toward the best vertex. Every trial point is clipped to
    /// the internal bounds.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the cost function.
    fn step(
        &mut self,
        _current_step: usize,
        problem: &P,
        status: &mut GradientFreeStatus,
        args: &U,
        config: &Self::Config,
    ) -> Result<(), E> {
        let h = self.simplex.worst();
        let s = self.simplex.second_worst();
        let l = self.simplex.best();
        // Centroid of every vertex except the worst.
        let c = &self.simplex.corrected_centroid();
        // Reflection: x_r = c + alpha (c - x_h).
        let mut xrx = c + (c - &h.x).scale(config.alpha);
        if let Some(ib) = self.internal_bounds.as_ref() {
            xrx = ib.clip_values(&xrx);
        }
        let mut xr = Point::from(xrx);
        xr.evaluate_transformed(problem, &self.resolved_transform, args)?;
        status.inc_n_f_evals();
        if l <= &xr && &xr < s {
            // Neither a new best nor worse than the second-worst vertex: accept as-is.
            self.simplex.insert_and_sort(self.simplex.dimension - 2, xr);
            status.set_message().step_with_message("REFLECT");
            self.simplex.scale_volume(config.alpha);
            return Ok(());
        } else if &xr < l {
            // Expansion: x_e = c + beta (x_r - c).
            let mut xex = c + (&xr.x - c).scale(config.beta);
            if let Some(ib) = self.internal_bounds.as_ref() {
                xex = ib.clip_values(&xex);
            }
            let mut xe = Point::from(xex);
            xe.evaluate_transformed(problem, &self.resolved_transform, args)?;
            status.inc_n_f_evals();
            // Whichever point is accepted must beat the current best, because it is inserted at
            // index 0 without re-sorting and reported as the new best position.
            let accepted = match config.expansion_method {
                SimplexExpansionMethod::GreedyMinimization => {
                    if xe < xr {
                        xe
                    } else {
                        xr
                    }
                }
                // Original Nelder–Mead rule (Lagarias et al.): take x_e whenever it improves on
                // the best vertex, even if x_r is better still; otherwise fall back to x_r
                // (which is known to beat the best).
                SimplexExpansionMethod::GreedyExpansion => {
                    if &xe < l {
                        xe
                    } else {
                        xr
                    }
                }
            };
            let accepted_fx = accepted.fx_checked();
            let accepted_x = self.resolved_transform.to_owned_external(&accepted.x);
            self.simplex.insert_sorted(0, accepted);
            status.set_position_silent((accepted_x, accepted_fx));
            status.set_message().step_with_message("EXPAND");
            self.simplex.scale_volume(config.alpha * config.beta);
            return Ok(());
        } else if s <= &xr {
            if &xr < h {
                // Outside contraction: x_c = c + gamma (x_r - c).
                let mut xcx = c + (&xr.x - c).scale(config.gamma);
                if let Some(ib) = self.internal_bounds.as_ref() {
                    xcx = ib.clip_values(&xcx);
                }
                let mut xc = Point::from(xcx);
                xc.evaluate_transformed(problem, &self.resolved_transform, args)?;
                status.inc_n_f_evals();
                if xc <= xr {
                    if &xc < s {
                        let xc_is_new_best = &xc < l;
                        let xc_fx = xc.fx_checked();
                        let xc_x = xc_is_new_best
                            .then(|| self.resolved_transform.to_owned_external(&xc.x));
                        self.simplex.insert_and_sort(self.simplex.dimension - 1, xc);
                        if let Some(xc_x) = xc_x {
                            status.set_position_silent((xc_x, xc_fx));
                        }
                    } else {
                        // Still no better than the second-worst: it is the new worst vertex.
                        self.simplex.insert_sorted(self.simplex.dimension - 1, xc);
                    }
                    status.set_message().step_with_message("CONTRACT OUT");
                    self.simplex.scale_volume(config.alpha * config.gamma);
                    return Ok(());
                }
            } else {
                // Inside contraction: x_c = c + gamma (x_h - c).
                let mut xcx = c + (&h.x - c).scale(config.gamma);
                if let Some(ib) = self.internal_bounds.as_ref() {
                    xcx = ib.clip_values(&xcx);
                }
                let mut xc = Point::from(xcx);
                xc.evaluate_transformed(problem, &self.resolved_transform, args)?;
                status.inc_n_f_evals();
                if &xc < h {
                    if &xc < s {
                        let xc_is_new_best = &xc < l;
                        let xc_fx = xc.fx_checked();
                        let xc_x = xc_is_new_best
                            .then(|| self.resolved_transform.to_owned_external(&xc.x));
                        self.simplex.insert_and_sort(self.simplex.dimension - 1, xc);
                        if let Some(xc_x) = xc_x {
                            status.set_position_silent((xc_x, xc_fx));
                        }
                    } else {
                        // Still no better than the second-worst: it is the new worst vertex.
                        self.simplex.insert_sorted(self.simplex.dimension - 1, xc);
                    }
                    status.set_message().step_with_message("CONTRACT IN");
                    self.simplex.scale_volume(config.gamma);
                    return Ok(());
                }
            }
        }
        // Shrink every vertex toward the best one: x_i = x_l + delta (x_i - x_l).
        let l_clone = l.clone();
        for p in self.simplex.points.iter_mut().skip(1) {
            let mut px = &l_clone.x + (&p.x - &l_clone.x).scale(config.delta);
            if let Some(ib) = self.internal_bounds.as_ref() {
                px = ib.clip_values(&px);
            }
            *p = Point::from(px);
            p.evaluate_transformed(problem, &self.resolved_transform, args)?;
            status.inc_n_f_evals();
        }
        self.simplex.sorted = false;
        self.simplex.sort();
        self.simplex.compute_total_centroid();
        status.set_position_silent(self.simplex.best_position(&self.resolved_transform));
        status.set_message().step_with_message("SHRINK");
        // An n-dimensional volume (n = vertices - 1) scales by delta^n under a uniform shrink.
        self.simplex
            .scale_volume(Float::powi(config.delta, self.simplex.dimension as i32 - 1));
        Ok(())
    }

    /// Builds the final summary from `status` and the stored starting point.
    ///
    /// If `status` carries no uncertainties, `std` is zero-filled and `covariance` is the
    /// identity.
    fn summarize(
        &self,
        _current_step: usize,
        _func: &P,
        status: &GradientFreeStatus,
        _args: &U,
        _init: &Self::Init,
        config: &Self::Config,
    ) -> Result<MinimizationSummary, E> {
        Ok(MinimizationSummary {
            x0: self.initial_x0.clone(),
            x: status.x.clone(),
            fx: status.fx,
            bounds: config.bounds.clone(),
            n_f_evals: status.n_f_evals,
            n_g_evals: 0,
            n_h_evals: 0,
            message: status.message.clone(),
            parameter_names: config.parameter_names.clone(),
            std: status
                .err
                .clone()
                .unwrap_or_else(|| DVector::from_element(status.x.len(), 0.0)),
            covariance: status
                .cov
                .clone()
                .unwrap_or_else(|| DMatrix::identity(status.x.len(), status.x.len())),
        })
    }

    /// Default terminators: the default function-value and simplex-size criteria.
    fn default_callbacks() -> Callbacks<Self, P, GradientFreeStatus, U, E, Self::Config>
    where
        Self: Sized,
    {
        Callbacks::empty()
            .with_terminator(NelderMeadFTerminator::default())
            .with_terminator(NelderMeadXTerminator::default())
    }

    /// Clears the simplex and the stored starting point.
    fn reset(&mut self) {
        self.simplex = Simplex::default();
        self.initial_x0 = DVector::default();
    }
}
impl<P, U, E> CheckpointableAlgorithm<P, GradientFreeStatus, U, E> for NelderMead
where
    P: CostFunction<U, E>,
{
    type Checkpoint = NelderMeadCheckpoint;

    /// Snapshots the simplex, starting point and status, together with the step to resume at.
    fn checkpoint(&self, status: &GradientFreeStatus, next_step: usize) -> Self::Checkpoint {
        let simplex = self.simplex.clone();
        let initial_x0 = self.initial_x0.clone();
        NelderMeadCheckpoint {
            simplex,
            initial_x0,
            status: status.clone(),
            next_step,
        }
    }

    /// Restores state from `checkpoint`.
    ///
    /// Bounds and transform are re-derived from `config`, since they are not stored in the
    /// checkpoint. Returns the saved status and the step index to resume from.
    fn restore(
        &mut self,
        checkpoint: &Self::Checkpoint,
        config: &Self::Config,
    ) -> (GradientFreeStatus, usize) {
        let NelderMeadCheckpoint {
            simplex,
            initial_x0,
            status,
            next_step,
        } = checkpoint;
        let (bounds, transform): (Option<Bounds>, Option<Box<dyn Transform>>) =
            resolve_bounds_and_transform(&config.bounds, &config.transform, config.bounds_handling);
        self.internal_bounds = bounds.map(|b| b.apply(&transform));
        self.resolved_transform = transform;
        self.simplex = simplex.clone();
        self.initial_x0 = initial_x0.clone();
        (status.clone(), *next_step)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::{
core::{AtomicCheckpointSignal, CheckpointOnSignal, CheckpointStore, MaxSteps},
test_functions::Rosenbrock,
traits::{AbortSignal, CheckpointableAlgorithm, Observer},
};
use approx::assert_relative_eq;
use nalgebra::dvector;
use std::convert::Infallible;
#[derive(Clone)]
struct TriggerAbortAtStep<Sig> {
target_step: usize,
signal: Sig,
}
impl<Sig> TriggerAbortAtStep<Sig> {
fn new(target_step: usize, signal: Sig) -> Self {
Self {
target_step,
signal,
}
}
}
impl<P, U, E, Sig> Observer<NelderMead, P, GradientFreeStatus, U, E, NelderMeadConfig>
for TriggerAbortAtStep<Sig>
where
P: CostFunction<U, E>,
Sig: AbortSignal + Clone,
{
fn observe(
&mut self,
current_step: usize,
_algorithm: &NelderMead,
_problem: &P,
_status: &GradientFreeStatus,
_args: &U,
_config: &NelderMeadConfig,
) {
if current_step == self.target_step {
self.signal.abort();
}
}
}
#[test]
fn test_nelder_mead() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let starting_values = vec![
[-2.0, 2.0],
[2.0, 2.0],
[2.0, -2.0],
[-2.0, -2.0],
[1.0, 1.0],
[0.0, 0.0],
];
for starting_value in starting_values {
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new(starting_value),
NelderMeadConfig::default(),
NelderMead::default_callbacks().with_terminator(MaxSteps(1_000_000)),
)
.unwrap();
assert!(result.message.success());
assert_relative_eq!(result.fx, 0.0, epsilon = Float::EPSILON.powf(0.2));
}
}
#[test]
fn nelder_mead_checkpoint_signal_resume_matches_uninterrupted_run() {
let problem = Rosenbrock { n: 2 };
let init = NelderMeadInit::new([2.0, 2.0]);
let config = NelderMeadConfig::default();
let uninterrupted = NelderMead::default()
.process(
&problem,
&(),
init.clone(),
config.clone(),
NelderMead::default_callbacks().with_terminator(MaxSteps(200)),
)
.unwrap();
let signal = AtomicCheckpointSignal::new();
let store = CheckpointStore::new();
let store_sink = store.clone();
let checkpointed = NelderMead::default()
.process(
&problem,
&(),
init.clone(),
config.clone(),
NelderMead::default_callbacks()
.with_terminator(MaxSteps(200))
.with_observer(TriggerAbortAtStep::new(4, signal.clone()))
.with_terminator(CheckpointOnSignal::new(signal, move |checkpoint| {
store_sink.save(checkpoint);
})),
)
.unwrap();
assert!(checkpointed.message.text.contains("Checkpoint requested"));
let checkpoint = store.load().unwrap();
let checkpoint_json = serde_json::to_string(&checkpoint).unwrap();
let checkpoint: NelderMeadCheckpoint = serde_json::from_str(&checkpoint_json).unwrap();
let resumed = NelderMead::default()
.process_from_checkpoint(
&problem,
&(),
init,
config,
&checkpoint,
NelderMead::default_callbacks().with_terminator(MaxSteps(200)),
)
.unwrap();
assert_relative_eq!(
resumed.fx,
uninterrupted.fx,
epsilon = Float::EPSILON.powf(0.2)
);
assert_relative_eq!(
resumed.x[0],
uninterrupted.x[0],
epsilon = Float::EPSILON.powf(0.2)
);
assert_relative_eq!(
resumed.x[1],
uninterrupted.x[1],
epsilon = Float::EPSILON.powf(0.2)
);
assert_eq!(resumed.n_f_evals, uninterrupted.n_f_evals);
}
#[test]
fn test_bounded_nelder_mead() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let starting_values = vec![
[-2.0, 2.0],
[2.0, 2.0],
[2.0, -2.0],
[-2.0, -2.0],
[1.0, 1.0],
[0.0, 0.0],
];
for starting_value in starting_values {
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new(starting_value),
NelderMeadConfig::default().with_bounds([(-4.0, 4.0), (-4.0, 4.0)]),
NelderMead::default_callbacks().with_terminator(MaxSteps(1_000_000)),
)
.unwrap();
assert!(result.message.success());
assert_relative_eq!(result.fx, 0.0, epsilon = Float::EPSILON.powf(0.2));
}
}
#[test]
fn test_transformed_nelder_mead() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let starting_values = vec![
[-2.0, 2.0],
[2.0, 2.0],
[2.0, -2.0],
[-2.0, -2.0],
[1.0, 1.0],
[0.0, 0.0],
];
for starting_value in starting_values {
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new(starting_value),
NelderMeadConfig::default()
.with_transform(&Bounds::from([(-4.0, 4.0), (-4.0, 4.0)])),
NelderMead::default_callbacks().with_terminator(MaxSteps(1_000_000)),
)
.unwrap();
assert!(result.message.success());
assert_relative_eq!(result.fx, 0.0, epsilon = Float::EPSILON.powf(0.2));
}
}
#[test]
fn test_adaptive_nelder_mead() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let starting_values = vec![
[-2.0, 2.0],
[2.0, 2.0],
[2.0, -2.0],
[-2.0, -2.0],
[1.0, 1.0],
[0.0, 0.0],
];
for starting_value in starting_values {
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new(starting_value),
NelderMeadConfig::default().with_adaptive(2).unwrap(),
NelderMead::default_callbacks().with_terminator(MaxSteps(1_000_000)),
)
.unwrap();
assert!(result.message.success());
assert_relative_eq!(result.fx, 0.0, epsilon = Float::EPSILON.powf(0.2));
}
}
fn point(x: &[Float], fx: Float) -> Point<DVector<Float>> {
Point {
x: DVector::from_column_slice(x),
fx: Some(fx),
}
}
#[test]
fn test_corrected_centroid() {
let pts = vec![
point(&[1.0, 2.0], 1.0),
point(&[2.0, 3.0], 2.0),
point(&[3.0, 4.0], 3.0),
];
let simplex = Simplex::new(&pts);
let expected = (&pts[0].x + &pts[1].x) / 2.0;
let actual = simplex.corrected_centroid();
assert_eq!(actual, expected);
}
#[test]
fn test_insert_sorted() {
let mut simplex = Simplex::new(&[
point(&[0.0, 0.0], 0.0),
point(&[1.0, 1.0], 1.0),
point(&[2.0, 2.0], 2.0),
]);
let original_total = simplex.total_centroid.clone();
let new_point = point(&[3.0, 3.0], 1.5);
simplex.insert_sorted(1, new_point.clone());
let expected_total = &original_total + (&new_point.x - &point(&[2.0, 2.0], 2.0).x) / 3.0;
assert_eq!(simplex.total_centroid.clone(), expected_total);
}
#[test]
fn test_insert_and_sort() {
let mut simplex = Simplex::new(&[
point(&[5.0, 0.0], 5.0),
point(&[1.0, 1.0], 1.0),
point(&[2.0, 2.0], 2.0),
]);
let original_total = simplex.total_centroid.clone();
let new_point = point(&[0.5, 0.5], 0.2);
simplex.insert_and_sort(0, new_point.clone());
let expected_total = &original_total + (&new_point.x - &point(&[5.0, 0.0], 5.0).x) / 3.0;
assert_eq!(simplex.best(), &new_point);
assert_eq!(simplex.total_centroid.clone(), expected_total);
}
#[test]
fn terminates_with_f_amoeba() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let cfg = NelderMeadConfig::default();
let callbacks =
Callbacks::empty().with_terminator(NelderMeadFTerminator::Amoeba { eps_rel: 0.01 });
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new([0.5, -0.5]),
cfg,
callbacks,
)
.unwrap();
assert!(result.message.success());
assert!(result.message.to_string().contains("term_f = AMOEBA"));
}
#[test]
fn terminates_with_f_absolute() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let cfg = NelderMeadConfig::default();
let callbacks = Callbacks::empty().with_terminator(NelderMeadFTerminator::Absolute {
eps_abs: Float::EPSILON.powf(0.25),
});
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new([0.5, -0.5]),
cfg,
callbacks,
)
.unwrap();
assert!(result.message.success());
assert!(result.message.to_string().contains("term_f = ABSOLUTE"));
}
#[test]
fn terminates_with_f_stddev() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let cfg = NelderMeadConfig::default();
let callbacks = Callbacks::empty().with_terminator(NelderMeadFTerminator::StdDev {
eps_abs: Float::EPSILON.powf(0.25),
});
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new([0.5, -0.5]),
cfg,
callbacks,
)
.unwrap();
assert!(result.message.success());
assert!(result.message.to_string().contains("term_f = STDDEV"));
}
#[test]
fn terminates_with_x_diameter() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let cfg = NelderMeadConfig::default();
let callbacks = Callbacks::empty().with_terminator(NelderMeadXTerminator::Diameter {
eps_abs: Float::EPSILON.powf(0.25),
});
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new([0.5, -0.5]),
cfg,
callbacks,
)
.unwrap();
assert!(result.message.success());
assert!(result.message.to_string().contains("term_x = DIAMETER"));
}
#[test]
fn terminates_with_x_higham() {
let mut solver = NelderMead::default();
let problem = Rosenbrock { n: 2 };
let cfg = NelderMeadConfig::default();
let callbacks = Callbacks::empty().with_terminator(NelderMeadXTerminator::Higham {
eps_rel: Float::EPSILON.powf(0.25),
});
let result = solver
.process(
&problem,
&(),
NelderMeadInit::new([0.5, -0.5]),
cfg,
callbacks,
)
.unwrap();
assert!(result.message.success());
assert!(result.message.to_string().contains("term_x = HIGHAM"));
}
/// Minimizes 2-D Rosenbrock from (0.5, -0.5) with only the Rowan parameter-space
/// terminator (relative tolerance `EPSILON^(1/4)`) installed, and checks that the
/// run succeeds and its status message names that terminator.
#[test]
fn terminates_with_x_rowan() {
    let tolerance = Float::EPSILON.powf(0.25);
    let problem = Rosenbrock { n: 2 };
    let mut solver = NelderMead::default();
    let summary = solver
        .process(
            &problem,
            &(),
            NelderMeadInit::new([0.5, -0.5]),
            NelderMeadConfig::default(),
            Callbacks::empty()
                .with_terminator(NelderMeadXTerminator::Rowan { eps_rel: tolerance }),
        )
        .unwrap();
    assert!(summary.message.success());
    let message = summary.message.to_string();
    assert!(message.contains("term_x = ROWAN"));
}
/// Minimizes 2-D Rosenbrock from (0.5, -0.5) with only the Singer parameter-space
/// terminator (relative tolerance `EPSILON^(1/4)`) installed, and checks that the
/// run succeeds and its status message names that terminator.
#[test]
fn terminates_with_x_singer() {
    let tolerance = Float::EPSILON.powf(0.25);
    let problem = Rosenbrock { n: 2 };
    let mut solver = NelderMead::default();
    let summary = solver
        .process(
            &problem,
            &(),
            NelderMeadInit::new([0.5, -0.5]),
            NelderMeadConfig::default(),
            Callbacks::empty()
                .with_terminator(NelderMeadXTerminator::Singer { eps_rel: tolerance }),
        )
        .unwrap();
    assert!(summary.message.success());
    let message = summary.message.to_string();
    assert!(message.contains("term_x = SINGER"));
}
/// `Simplex::new` must expose, as `total_centroid`, the arithmetic mean of the
/// positions of all its vertices (function values play no role).
#[test]
fn simplex_total_centroid_matches_mean() {
    let vertices = [
        point(&[0.0, 0.0], 3.0),
        point(&[2.0, 0.0], 2.0),
        point(&[0.0, 2.0], 1.0),
    ];
    let simplex = Simplex::new(&vertices);
    // Same summation order as a left-to-right sum over the vertices.
    let expected = (&vertices[0].x + &vertices[1].x + &vertices[2].x) / 3.0;
    assert_relative_eq!(simplex.total_centroid, expected);
}
/// `Simplex::scale_volume` must multiply the simplex's stored `volume` by the
/// given factor.
#[test]
fn simplex_scale_volume_multiplies() {
    let factor = 2.5;
    let mut simplex = Simplex::new(&[
        point(&[0.0, 0.0], 3.0),
        point(&[2.0, 0.0], 2.0),
        point(&[0.0, 2.0], 1.0),
    ]);
    let volume_before = simplex.volume;
    simplex.scale_volume(factor);
    assert_relative_eq!(simplex.volume, volume_before * factor);
}
/// The best vertex is (1, 1). Both other vertices lie sqrt(2) (< 1.5) away from it,
/// even though they are 2*sqrt(2) (> 1.5) apart from each other. A diameter check
/// with `eps_abs = 1.5` is therefore expected to stop, and to report itself in the
/// status message, which a full pairwise diameter would not allow.
#[test]
fn diameter_terminator_uses_best_point_as_reference() {
    let simplex = Simplex::new(&[
        point(&[1.0, 1.0], 0.0),
        point(&[0.0, 0.0], 1.0),
        point(&[2.0, 2.0], 2.0),
    ]);
    let mut solver = NelderMead {
        simplex,
        ..Default::default()
    };
    let problem = Rosenbrock { n: 2 };
    let config = NelderMeadConfig::default();
    let mut status = GradientFreeStatus::default();
    let flow = NelderMeadXTerminator::Diameter { eps_abs: 1.5 }.check_for_termination(
        0,
        &mut solver,
        &problem,
        &mut status,
        &(),
        &config,
    );
    assert!(flow.is_break());
    assert!(status.message.to_string().contains("term_x = DIAMETER"));
}
/// Builds a simplex whose best vertex is (10, 10), with one close neighbour
/// (9.5, 10) and one distant vertex at the origin. It then checks that the Higham
/// terminator with `eps_rel = 2.0` requests a stop and names itself in the status
/// message.
#[test]
fn higham_terminator_uses_best_point_as_reference() {
    let simplex = Simplex::new(&[
        point(&[10.0, 10.0], 0.0),
        point(&[9.5, 10.0], 1.0),
        point(&[0.0, 0.0], 2.0),
    ]);
    let mut solver = NelderMead {
        simplex,
        ..Default::default()
    };
    let problem = Rosenbrock { n: 2 };
    let config = NelderMeadConfig::default();
    let mut status = GradientFreeStatus::default();
    let flow = NelderMeadXTerminator::Higham { eps_rel: 2.0 }.check_for_termination(
        0,
        &mut solver,
        &problem,
        &mut status,
        &(),
        &config,
    );
    assert!(flow.is_break());
    assert!(status.message.to_string().contains("term_x = HIGHAM"));
}
/// For a 3-vertex (2-parameter) simplex, scaling by `delta^(dimension - 1)` must
/// equal scaling by `delta^2`. This pins `dimension - 1` to the number of
/// parameters, which is the exponent a uniform shrink by `delta` applies to volume.
#[test]
fn shrink_volume_uses_parameter_dimension_exponent() {
    let delta: Float = 0.5;
    let mut simplex = Simplex::new(&[
        point(&[0.0, 0.0], 3.0),
        point(&[2.0, 0.0], 2.0),
        point(&[0.0, 2.0], 1.0),
    ]);
    let volume_before = simplex.volume;
    let exponent = simplex.dimension as i32 - 1;
    simplex.scale_volume(delta.powi(exponent));
    assert_relative_eq!(simplex.volume, volume_before * delta.powi(2));
}
/// Generating a scaled-orthogonal simplex for a one-parameter problem must panic
/// with the dimension-requirement message.
#[test]
#[should_panic(
    expected = "Nelder-Mead is only a suitable method for problems of dimension >= 2"
)]
fn orthogonal_simplex_panics_in_1d() {
    let problem = Rosenbrock { n: 1 };
    // The convenience constructor uses multiplier 1.05 and zero step 0.00025.
    let method = SimplexConstructionMethod::scaled_orthogonal([1.0]);
    let _ = method
        .generate::<_, Infallible>(&problem, &None, None, &())
        .unwrap();
}
/// A custom simplex is generated from exactly the supplied vertices, and its
/// vertices come back ordered by function value. The Rosenbrock minimum (1, 1) is
/// the best vertex, and best <= second-worst <= worst holds.
#[test]
fn custom_simplex_ignores_x0_and_sorts_by_fx() {
    let problem = Rosenbrock { n: 2 };
    let vertices = vec![dvector![2.0, 2.0], dvector![1.0, 1.0], dvector![0.0, 0.0]];
    let method = SimplexConstructionMethod::Custom { simplex: vertices };
    let simplex = method
        .generate::<_, Infallible>(&problem, &None, None, &())
        .unwrap();
    let best = simplex.best();
    assert_relative_eq!(best.x[0], 1.0);
    assert_relative_eq!(best.x[1], 1.0);
    assert!(best.fx <= simplex.second_worst().fx);
    assert!(simplex.second_worst().fx <= simplex.worst().fx);
}
/// `with_adaptive(n)` must reproduce the Gao & Han (2012) dimension-dependent
/// coefficients:
///
/// `alpha = 1`, `beta = 1 + 2/n`, `gamma = 3/4 - 1/(2n)`, `delta = 1 - 1/n`.
///
/// At `n = 2` the values of `gamma` and `delta` coincide (both 0.5), so a check at
/// `n = 2` alone cannot catch them being swapped or mis-derived. `n = 3` is
/// checked as well.
#[test]
fn adaptive_parameters_match_gao_han() {
    let cfg = NelderMeadConfig::default().with_adaptive(2).unwrap();
    assert_relative_eq!(cfg.alpha, 1.0);
    assert_relative_eq!(cfg.beta, 2.0);
    assert_relative_eq!(cfg.gamma, 0.5);
    assert_relative_eq!(cfg.delta, 0.5);

    // n = 3: beta = 5/3, gamma = 7/12, delta = 2/3 — all distinct.
    let cfg = NelderMeadConfig::default().with_adaptive(3).unwrap();
    assert_relative_eq!(cfg.alpha, 1.0);
    assert_relative_eq!(cfg.beta, 1.0 + 2.0 / 3.0);
    assert_relative_eq!(cfg.gamma, 0.75 - 1.0 / 6.0);
    assert_relative_eq!(cfg.delta, 1.0 - 1.0 / 3.0);
}
/// A user-supplied initial simplex combined with greedy expansion must still
/// minimize 2-D Rosenbrock successfully under the default callbacks.
#[test]
fn expansion_and_construction_method_switches_are_accepted() {
    let problem = Rosenbrock { n: 2 };
    let init = NelderMeadInit::custom(vec![
        dvector![0.5, -0.5],
        dvector![1.5, -0.5],
        dvector![0.5, 0.5],
    ])
    .unwrap();
    let config = NelderMeadConfig::default()
        .with_expansion_method(SimplexExpansionMethod::GreedyExpansion);
    let mut solver = NelderMead::default();
    let summary = solver
        .process(&problem, &(), init, config, NelderMead::default_callbacks())
        .unwrap();
    assert!(summary.message.success());
}
/// `NelderMeadInit::custom` must reject three malformed inputs, each with an
/// explanatory error:
///
/// - an empty vertex list;
/// - vertices of differing dimension;
/// - a vertex count other than `n + 1`.
#[test]
fn custom_simplex_config_rejects_invalid_shapes() {
    // Returns the rendered error, or panics with `failure` if the input was accepted.
    let rejection_message = |vertices: Vec<DVector<Float>>, failure: &str| -> String {
        match NelderMeadInit::custom(vertices) {
            Ok(_) => panic!("{}", failure),
            Err(err) => err.to_string(),
        }
    };

    let empty = rejection_message(
        Vec::<DVector<Float>>::new(),
        "empty custom simplex should be rejected",
    );
    assert!(empty.contains("must not be empty"));

    let mixed = rejection_message(
        vec![dvector![1.0, 2.0], dvector![3.0], dvector![4.0, 5.0]],
        "mixed-dimension custom simplex should be rejected",
    );
    assert!(mixed.contains("same dimension"));

    // Two 2-D vertices where three are required.
    let short = rejection_message(
        vec![dvector![1.0, 2.0], dvector![3.0, 4.0]],
        "wrong-size custom simplex should be rejected",
    );
    assert!(short.contains("exactly n + 1 points"));
}
/// `with_alpha(0.0)` must return an `Err`.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_alpha_panics_on_nonpositive() {
    let _ = NelderMeadConfig::default().with_alpha(0.0).unwrap();
}
/// `with_beta(1.0)` (not strictly greater than one) must return an `Err`.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_beta_panics_when_not_gt_one() {
    let _ = NelderMeadConfig::default().with_beta(1.0).unwrap();
}
/// Valid `(alpha, beta)` passed to `with_alpha_beta` must be stored exactly.
#[test]
fn with_alpha_beta_sets_values() {
    let config = NelderMeadConfig::default()
        .with_alpha_beta(1.1, 2.2)
        .unwrap();
    assert_eq!((config.alpha, config.beta), (1.1, 2.2));
}
/// `with_alpha_beta(0.0, 2.0)` must return an `Err` because alpha is not positive.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_alpha_beta_panics_when_alpha_nonpositive() {
    let _ = NelderMeadConfig::default()
        .with_alpha_beta(0.0, 2.0)
        .unwrap();
}
/// `with_alpha_beta(0.5, 1.0)` must return an `Err` because beta is not strictly
/// greater than one.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_alpha_beta_panics_when_beta_not_gt_one() {
    let _ = NelderMeadConfig::default()
        .with_alpha_beta(0.5, 1.0)
        .unwrap();
}
/// `with_alpha_beta(1.6, 1.5)` must return an `Err`: beta exceeds one but not alpha.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_alpha_beta_panics_when_beta_not_gt_alpha() {
    let _ = NelderMeadConfig::default()
        .with_alpha_beta(1.6, 1.5)
        .unwrap();
}
/// After a valid `with_alpha(1.5)`, `with_beta(1.4)` must return an `Err` because
/// beta does not exceed alpha.
///
/// The setup step uses `expect`, whose panic text does not match the `expected`
/// filter. So only the `unwrap()` of the rejected `with_beta` call can satisfy
/// this test, not a wrongly rejected alpha.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_beta_panics_when_not_gt_alpha() {
    let config = NelderMeadConfig::default()
        .with_alpha(1.5)
        .expect("alpha = 1.5 should be accepted");
    let _ = config.with_beta(1.4).unwrap();
}
/// `with_gamma(0.0)` (outside the open unit interval) must return an `Err`.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_gamma_panics_if_not_in_unit() {
    let _ = NelderMeadConfig::default().with_gamma(0.0).unwrap();
}
/// `with_delta(1.0)` (outside the open unit interval) must return an `Err`.
///
/// The `expected` filter ties the panic to `unwrap()` on that `Err`, so an
/// unrelated panic cannot make this test pass.
#[test]
#[should_panic(expected = "called `Result::unwrap()` on an `Err` value")]
fn with_delta_panics_if_not_in_unit() {
    let _ = NelderMeadConfig::default().with_delta(1.0).unwrap();
}
/// Minimizes 2-D Rosenbrock from (-3, 3) through a bounds transform on
/// [-4, 4] x [-4, 4] (step cap 200 000). Checks that the run succeeds and that the
/// gradient-free method reports zero gradient evaluations.
#[test]
fn check_bounds_and_num_gradient_evals() {
    let problem = Rosenbrock { n: 2 };
    let bounds = Bounds::from([(-4.0, 4.0), (-4.0, 4.0)]);
    let config = NelderMeadConfig::default().with_transform(&bounds);
    let callbacks = NelderMead::default_callbacks().with_terminator(MaxSteps(200_000));
    let mut solver = NelderMead::default();
    let summary = solver
        .process(
            &problem,
            &(),
            NelderMeadInit::new([-3.0, 3.0]),
            config,
            callbacks,
        )
        .unwrap();
    assert!(summary.message.success());
    assert_eq!(summary.n_g_evals, 0);
}
/// After a run capped at two steps, the summary must report:
///
/// - at least three function evaluations;
/// - no gradient evaluations;
/// - the max-steps termination message.
#[test]
fn summary_reports_simplex_init_evals_and_terminal_message() {
    let problem = Rosenbrock { n: 2 };
    let mut solver = NelderMead::default();
    let summary = solver
        .process(
            &problem,
            &(),
            NelderMeadInit::new([0.5, -0.5]),
            NelderMeadConfig::default(),
            Callbacks::empty().with_terminator(MaxSteps(2)),
        )
        .unwrap();
    // Lower bound of 3: the initial 2-D simplex has three vertices, and per this
    // test's name their evaluations are expected to be counted.
    assert!(summary.n_f_evals >= 3);
    assert_eq!(summary.n_g_evals, 0);
    let message = summary.message.to_string();
    assert!(message.contains("Maximum number of steps reached"));
}
/// Choosing `BoundsHandlingMode::TransformBounds` on a bounded config must be
/// reflected in the config's `bounds_handling` field.
#[test]
fn transform_bounds_mode_is_selectable_for_nelder_mead() {
    let mode = BoundsHandlingMode::TransformBounds;
    let config = NelderMeadConfig::default()
        .with_bounds([(0.0, 1.0), (0.0, 1.0)])
        .with_bounds_handling(mode);
    let selected = matches!(
        config.bounds_handling,
        BoundsHandlingMode::TransformBounds
    );
    assert!(selected);
}
}