use num_traits::Float;
use std::fmt::Debug;
use crate::math::optimization::{ObjectiveFunction, OptimizationConfig, OptimizationResult};
/// Minimizes `f` with the Nelder–Mead downhill-simplex method.
///
/// Builds a simplex around `initial_point`, then repeatedly applies
/// reflection, expansion, contraction, and shrink steps until both the
/// simplex diameter and the best-to-worst value spread drop below
/// `config.tolerance`, or `config.max_iterations` is exhausted.
///
/// Returns the best point and value seen, the iteration count, and
/// whether the tolerance test was met.
pub fn minimize<T, F>(
    f: &F,
    initial_point: &[T],
    config: &OptimizationConfig<T>,
) -> OptimizationResult<T>
where
    T: Float + Debug,
    F: ObjectiveFunction<T>,
{
    // Standard Nelder–Mead coefficients: reflection (alpha),
    // expansion (gamma), contraction (rho), shrink (sigma).
    let alpha = T::from(1.0).unwrap();
    let gamma = T::from(2.0).unwrap();
    let rho = T::from(0.5).unwrap();
    let sigma = T::from(0.5).unwrap();

    let n = initial_point.len();
    let mut iterations = 0;
    let mut converged = false;
    let mut simplex = initialize_simplex(initial_point);
    let mut values = evaluate_simplex(f, &simplex);
    let mut best_value = values[0];
    let mut best_point = simplex[0].clone();

    // Zero-dimensional input: there is nothing to optimize, and the loop
    // body below would panic computing an empty centroid. Return at once.
    if n == 0 {
        return OptimizationResult {
            optimal_point: best_point,
            optimal_value: best_value,
            iterations,
            converged: true,
        };
    }

    while iterations < config.max_iterations {
        // Sort so values[0]/simplex[0] is best and values[n]/simplex[n] is worst.
        order_simplex(&mut simplex, &mut values);
        if values[0] < best_value {
            best_value = values[0];
            best_point = simplex[0].clone();
        }

        // Centroid of every vertex except the worst one.
        let centroid = compute_centroid(&simplex[..n]);
        let size_measure = compute_simplex_size(&simplex);
        let value_range = values[n] - values[0];
        if size_measure < config.tolerance && value_range < config.tolerance {
            converged = true;
            break;
        }

        // Reflect the worst vertex through the centroid.
        let reflected = reflect(&centroid, &simplex[n], alpha);
        let reflected_value = f.evaluate(&reflected);
        if values[0] <= reflected_value && reflected_value < values[n - 1] {
            // Reflection is good enough: replace the worst vertex.
            simplex[n] = reflected;
            values[n] = reflected_value;
        } else if reflected_value < values[0] {
            // Reflection found a new best point: try expanding further out.
            let expanded = reflect(&centroid, &simplex[n], gamma);
            let expanded_value = f.evaluate(&expanded);
            if expanded_value < reflected_value {
                simplex[n] = expanded;
                values[n] = expanded_value;
            } else {
                simplex[n] = reflected;
                values[n] = reflected_value;
            }
        } else {
            // Reflection did not improve on the second-worst vertex:
            // contract toward the centroid (negative coefficient pulls
            // the trial point inside the simplex).
            let contracted = reflect(&centroid, &simplex[n], -rho);
            let contracted_value = f.evaluate(&contracted);
            if contracted_value < values[n] {
                simplex[n] = contracted;
                values[n] = contracted_value;
            } else {
                // Contraction failed too: shrink every non-best vertex
                // toward the best vertex. (The previous code evaluated
                // shrunk points but never stored them back into the
                // simplex, leaving `simplex` and `values` inconsistent.)
                let best_vertex = simplex[0].clone();
                for i in 1..=n {
                    for (j, coord) in simplex[i].iter_mut().enumerate() {
                        *coord = best_vertex[j] + sigma * (*coord - best_vertex[j]);
                    }
                    values[i] = f.evaluate(&simplex[i]);
                }
            }
        }
        iterations += 1;
    }

    OptimizationResult {
        optimal_point: best_point,
        optimal_value: best_value,
        iterations,
        converged,
    }
}
/// Builds the starting simplex: the initial point plus one vertex per
/// coordinate axis, each perturbed by 10% along that axis (or set to the
/// perturbation scale itself when the coordinate is exactly zero).
fn initialize_simplex<T>(initial_point: &[T]) -> Vec<Vec<T>>
where
    T: Float + Debug,
{
    let scale = T::from(0.1).unwrap();
    let dim = initial_point.len();
    let mut simplex = Vec::with_capacity(dim + 1);
    simplex.push(initial_point.to_vec());
    for axis in 0..dim {
        let mut vertex = initial_point.to_vec();
        // A zero coordinate cannot be scaled multiplicatively, so use the
        // absolute perturbation instead.
        vertex[axis] = if vertex[axis] == T::zero() {
            scale
        } else {
            vertex[axis] * (T::one() + scale)
        };
        simplex.push(vertex);
    }
    simplex
}
/// Evaluates the objective at every simplex vertex, in vertex order.
fn evaluate_simplex<T, F>(f: &F, simplex: &[Vec<T>]) -> Vec<T>
where
    T: Float + Debug,
    F: ObjectiveFunction<T>,
{
    let mut values = Vec::with_capacity(simplex.len());
    for vertex in simplex {
        values.push(f.evaluate(vertex));
    }
    values
}
/// Sorts vertices by ascending objective value, keeping `simplex` and
/// `values` paired (each swap is applied to both arrays).
///
/// Bubble sort is used deliberately: the simplex has only n+1 vertices
/// and is nearly sorted between iterations.
fn order_simplex<T>(simplex: &mut [Vec<T>], values: &mut [T])
where
    T: Float + Debug,
{
    let last = values.len() - 1;
    for pass in 0..last {
        let mut j = 0;
        // After each pass the largest unsorted value has bubbled to the end.
        while j < last - pass {
            if values[j] > values[j + 1] {
                values.swap(j, j + 1);
                simplex.swap(j, j + 1);
            }
            j += 1;
        }
    }
}
/// Computes the component-wise mean of the given points.
///
/// Panics if `points` is empty (the dimension is read from the first point).
fn compute_centroid<T>(points: &[Vec<T>]) -> Vec<T>
where
    T: Float + Debug,
{
    let dim = points[0].len();
    let count = T::from(points.len()).unwrap();
    (0..dim)
        .map(|axis| {
            let total = points
                .iter()
                .fold(T::zero(), |acc, point| acc + point[axis]);
            total / count
        })
        .collect()
}
/// Returns `centroid + coefficient * (centroid - point)` component-wise.
///
/// A positive coefficient reflects `point` through the centroid (and
/// beyond, for coefficients > 1); a negative coefficient produces a point
/// between the centroid and `point` (contraction).
fn reflect<T>(centroid: &[T], point: &[T], coefficient: T) -> Vec<T>
where
    T: Float + Debug,
{
    let mut result = Vec::with_capacity(centroid.len());
    for (&c, &p) in centroid.iter().zip(point.iter()) {
        result.push(c + coefficient * (c - p));
    }
    result
}
/// Returns the diameter of the simplex: the largest Euclidean distance
/// between any pair of vertices. Used as the convergence size measure.
fn compute_simplex_size<T>(simplex: &[Vec<T>]) -> T
where
    T: Float + Debug,
{
    let mut largest = T::zero();
    for (i, a) in simplex.iter().enumerate() {
        for b in simplex.iter().skip(i + 1) {
            let squared = a
                .iter()
                .zip(b.iter())
                .map(|(&x, &y)| (x - y) * (x - y))
                .fold(T::zero(), |acc, term| acc + term);
            largest = largest.max(squared.sqrt());
        }
    }
    largest
}
#[cfg(test)]
mod tests {
    use super::*;

    /// f(x) = Σ xᵢ², minimized at the origin with value 0.
    struct Quadratic;

    impl ObjectiveFunction<f64> for Quadratic {
        fn evaluate(&self, point: &[f64]) -> f64 {
            point.iter().fold(0.0, |acc, &x| acc + x * x)
        }
    }

    #[test]
    fn test_nelder_mead_quadratic() {
        let config = OptimizationConfig {
            max_iterations: 100,
            tolerance: 1e-6,
            learning_rate: 1.0,
        };
        let result = minimize(&Quadratic, &[1.0, 1.0], &config);
        assert!(result.converged);
        assert!(result.optimal_value < 1e-10);
        for coord in result.optimal_point {
            assert!(coord.abs() < 1e-5);
        }
    }

    /// f(x) = (x - 2)², minimized at x = 2 with value 0.
    struct QuadraticWithMinimum;

    impl ObjectiveFunction<f64> for QuadraticWithMinimum {
        fn evaluate(&self, point: &[f64]) -> f64 {
            let d = point[0] - 2.0;
            d * d
        }
    }

    #[test]
    fn test_nelder_mead_quadratic_with_minimum() {
        let config = OptimizationConfig {
            max_iterations: 100,
            tolerance: 1e-6,
            learning_rate: 1.0,
        };
        let result = minimize(&QuadraticWithMinimum, &[0.0], &config);
        assert!(result.converged);
        assert!((result.optimal_point[0] - 2.0).abs() < 1e-5);
        assert!(result.optimal_value < 1e-10);
    }

    /// Rosenbrock banana function, minimized at (1, 1) with value 0.
    struct Rosenbrock;

    impl ObjectiveFunction<f64> for Rosenbrock {
        fn evaluate(&self, point: &[f64]) -> f64 {
            let (x, y) = (point[0], point[1]);
            (x - 1.0).powi(2) + 100.0 * (y - x.powi(2)).powi(2)
        }
    }

    #[test]
    fn test_nelder_mead_rosenbrock() {
        let config = OptimizationConfig {
            max_iterations: 1000,
            tolerance: 1e-6,
            learning_rate: 1.0,
        };
        let result = minimize(&Rosenbrock, &[0.0, 0.0], &config);
        assert!(result.converged);
        assert!((result.optimal_point[0] - 1.0).abs() < 1e-3);
        assert!((result.optimal_point[1] - 1.0).abs() < 1e-3);
    }
}