use numr::dtype::DType;
use numr::error::Result;
use numr::ops::{CompareOps, ScalarOps, TensorOps};
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;
use crate::optimize::error::{OptimizeError, OptimizeResult};
use crate::optimize::minimize::MinimizeOptions;
use super::helpers::{TensorMinimizeResult, compare_f64_nan_safe};
use super::utils::SINGULAR_THRESHOLD;
pub fn nelder_mead_impl<R, C, F>(
client: &C,
f: F,
x0: &Tensor<R>,
options: &MinimizeOptions,
) -> OptimizeResult<TensorMinimizeResult<R>>
where
R: Runtime<DType = DType>,
C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
F: Fn(&Tensor<R>) -> Result<f64>,
{
let n = x0.shape()[0];
if n == 0 {
return Err(OptimizeError::InvalidInput {
context: "nelder_mead: empty initial guess".to_string(),
});
}
let alpha = 1.0; let gamma = 2.0; let rho = 0.5; let sigma = 0.5;
let simplex = initialize_simplex(client, x0, n)?;
let mut f_values = Vec::with_capacity(n + 1);
let mut nfev = 0;
for i in 0..=n {
let vertex = extract_row(client, &simplex, i, n)?;
let fval = f(&vertex).map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: initial evaluation - {}", e),
})?;
f_values.push(fval);
nfev += 1;
}
let mut vertices: Vec<Tensor<R>> = Vec::with_capacity(n + 1);
for i in 0..=n {
vertices.push(extract_row(client, &simplex, i, n)?);
}
for iter in 0..options.max_iter {
let mut indices: Vec<usize> = (0..=n).collect();
indices.sort_by(|&a, &b| compare_f64_nan_safe(f_values[a], f_values[b]));
let best_idx = indices[0];
let worst_idx = indices[n];
let second_worst_idx = indices[n - 1];
if f_values[best_idx].is_nan() {
return Err(OptimizeError::NumericalError {
message: "nelder_mead: all function values are NaN".to_string(),
});
}
let f_range = f_values[worst_idx] - f_values[best_idx];
if f_range < options.f_tol {
return Ok(TensorMinimizeResult {
x: vertices[best_idx].clone(),
fun: f_values[best_idx],
iterations: iter + 1,
nfev,
converged: true,
});
}
let centroid = compute_centroid(client, &vertices, &indices[..n])?;
let worst = &vertices[worst_idx];
let reflected = reflect_point(client, ¢roid, worst, alpha)?;
let f_reflected = f(&reflected).map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: reflection - {}", e),
})?;
nfev += 1;
if f_reflected < f_values[second_worst_idx] && f_reflected >= f_values[best_idx] {
vertices[worst_idx] = reflected;
f_values[worst_idx] = f_reflected;
continue;
}
if f_reflected < f_values[best_idx] {
let expanded = reflect_point(client, ¢roid, &reflected, -gamma)?;
let f_expanded = f(&expanded).map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: expansion - {}", e),
})?;
nfev += 1;
if f_expanded < f_reflected {
vertices[worst_idx] = expanded;
f_values[worst_idx] = f_expanded;
} else {
vertices[worst_idx] = reflected;
f_values[worst_idx] = f_reflected;
}
continue;
}
let (contracted, f_contracted) = if f_reflected < f_values[worst_idx] {
let contracted = reflect_point(client, ¢roid, &reflected, -rho)?;
let f_contracted = f(&contracted).map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: outside contraction - {}", e),
})?;
nfev += 1;
(contracted, f_contracted)
} else {
let contracted = reflect_point(client, ¢roid, worst, -rho)?;
let f_contracted = f(&contracted).map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: inside contraction - {}", e),
})?;
nfev += 1;
(contracted, f_contracted)
};
if f_contracted < f_values[worst_idx].min(f_reflected) {
vertices[worst_idx] = contracted;
f_values[worst_idx] = f_contracted;
continue;
}
let best = &vertices[best_idx].clone();
for &idx in &indices[1..=n] {
let diff =
client
.sub(&vertices[idx], best)
.map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: shrink diff - {}", e),
})?;
let scaled =
client
.mul_scalar(&diff, sigma)
.map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: shrink scale - {}", e),
})?;
vertices[idx] =
client
.add(best, &scaled)
.map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: shrink add - {}", e),
})?;
f_values[idx] = f(&vertices[idx]).map_err(|e| OptimizeError::NumericalError {
message: format!("nelder_mead: shrink eval - {}", e),
})?;
nfev += 1;
}
}
let mut best_idx = 0;
for i in 1..=n {
if f_values[i] < f_values[best_idx] {
best_idx = i;
}
}
Ok(TensorMinimizeResult {
x: vertices[best_idx].clone(),
fun: f_values[best_idx],
iterations: options.max_iter,
nfev,
converged: false,
})
}
/// Builds the initial `(n + 1) x n` simplex around `x0`.
///
/// Row 0 is `x0` itself. Row `i + 1` is `x0` with only component `i` perturbed:
/// - by `0.05 * x0[i]` when `|x0[i]| > SINGULAR_THRESHOLD`;
/// - otherwise by the fixed step `0.00025`, so near-zero components still move.
///
/// Helper tensors (threshold, small step, identity) are created in `x0`'s own
/// dtype rather than a fixed F64, so the elementwise ops never mix dtypes.
///
/// # Errors
/// `OptimizeError::NumericalError` if any tensor operation fails.
fn initialize_simplex<R, C>(client: &C, x0: &Tensor<R>, n: usize) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + CompareOps<R> + RuntimeClient<R>,
{
    let dtype = x0.dtype();

    // Mask of components large enough for a relative perturbation.
    let abs_x0 = client.abs(x0).map_err(|e| OptimizeError::NumericalError {
        message: format!("nelder_mead: abs x0 - {}", e),
    })?;
    let threshold_tensor = client
        .fill(&[n], SINGULAR_THRESHOLD, dtype)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: threshold tensor - {}", e),
        })?;
    let mask_raw = client
        .gt(&abs_x0, &threshold_tensor)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: gt comparison - {}", e),
        })?;
    let mask = client
        .cast(&mask_raw, DType::U8)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: cast mask - {}", e),
        })?;

    // Per-component step: 5% of the value, or a small absolute step near zero.
    let large_delta = client
        .mul_scalar(x0, 0.05)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: large delta - {}", e),
        })?;
    let small_delta = client
        .fill(&[n], 0.00025, dtype)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: small delta - {}", e),
        })?;
    let deltas = client
        .where_cond(&mask, &large_delta, &small_delta)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: select deltas - {}", e),
        })?;

    // diag(deltas): row i holds deltas[i] at column i, zeros elsewhere.
    let identity = client
        .eye(n, None, dtype)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: eye - {}", e),
        })?;
    let deltas_broadcast = deltas
        .unsqueeze(0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: unsqueeze deltas - {}", e),
        })?
        .broadcast_to(&[n, n])
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: broadcast deltas - {}", e),
        })?;
    let perturbation = client
        .mul(&identity, &deltas_broadcast)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: perturbation matrix - {}", e),
        })?;

    // Rows 1..=n: x0 plus one perturbed coordinate each.
    let x0_broadcast = x0
        .unsqueeze(0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: unsqueeze x0 - {}", e),
        })?
        .broadcast_to(&[n, n])
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: broadcast x0 - {}", e),
        })?;
    let perturbed = client
        .add(&x0_broadcast, &perturbation)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: perturbed vertices - {}", e),
        })?;

    // Prepend x0 as row 0.
    let x0_row = x0
        .contiguous()
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: contiguous x0 - {}", e),
        })?
        .unsqueeze(0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: x0 row - {}", e),
        })?;
    let perturbed_contig = perturbed
        .contiguous()
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: contiguous perturbed - {}", e),
        })?;

    client
        .cat(&[&x0_row, &perturbed_contig], 0)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: concat simplex - {}", e),
        })
}
/// Returns row `row` of the 2-D `matrix` as a contiguous 1-D tensor of length `n`.
///
/// `_client` is not used; it is accepted so the call shape matches the other
/// helpers in this module.
///
/// # Errors
/// `OptimizeError::NumericalError` if narrowing, copying to contiguous memory,
/// or reshaping fails.
fn extract_row<R, C>(
    _client: &C,
    matrix: &Tensor<R>,
    row: usize,
    n: usize,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: RuntimeClient<R>,
{
    // A 1 x n view of the requested row.
    let single_row = matrix.narrow(0, row, 1).map_err(|err| {
        OptimizeError::NumericalError {
            message: format!("nelder_mead: narrow row - {}", err),
        }
    })?;

    // Materialize it so the reshape below is a plain relabelling.
    let packed = single_row.contiguous().map_err(|err| {
        OptimizeError::NumericalError {
            message: format!("nelder_mead: contiguous row - {}", err),
        }
    })?;

    packed.reshape(&[n]).map_err(|err| OptimizeError::NumericalError {
        message: format!("nelder_mead: reshape row - {}", err),
    })
}
/// Returns the arithmetic mean of `vertices[i]` for every `i` in `indices`.
///
/// # Errors
/// `OptimizeError::NumericalError` if `indices` is empty or if any tensor
/// addition or scaling fails.
fn compute_centroid<R, C>(
    client: &C,
    vertices: &[Tensor<R>],
    indices: &[usize],
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let (&head, tail) = indices
        .split_first()
        .ok_or_else(|| OptimizeError::NumericalError {
            message: "nelder_mead: empty indices for centroid".to_string(),
        })?;

    // Accumulate the selected vertices, starting from the first one.
    let total = tail.iter().try_fold(vertices[head].clone(), |acc, &idx| {
        client
            .add(&acc, &vertices[idx])
            .map_err(|e| OptimizeError::NumericalError {
                message: format!("nelder_mead: centroid sum - {}", e),
            })
    })?;

    let count = indices.len() as f64;
    client
        .mul_scalar(&total, 1.0 / count)
        .map_err(|e| OptimizeError::NumericalError {
            message: format!("nelder_mead: centroid div - {}", e),
        })
}
/// Computes `base + coeff * (base - point)`.
///
/// A positive `coeff` moves away from `point` through `base`. A negative
/// `coeff` moves from `base` toward `point`: `coeff = -1` lands exactly on
/// `point`, and `coeff < -1` goes past it.
///
/// # Errors
/// `OptimizeError::NumericalError` if the subtraction, scaling, or addition fails.
fn reflect_point<R, C>(
    client: &C,
    base: &Tensor<R>,
    point: &Tensor<R>,
    coeff: f64,
) -> OptimizeResult<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + ScalarOps<R> + RuntimeClient<R>,
{
    let offset = client.sub(base, point).map_err(|err| {
        OptimizeError::NumericalError {
            message: format!("nelder_mead: reflect diff - {}", err),
        }
    })?;

    let step = client.mul_scalar(&offset, coeff).map_err(|err| {
        OptimizeError::NumericalError {
            message: format!("nelder_mead: reflect scale - {}", err),
        }
    })?;

    client.add(base, &step).map_err(|err| OptimizeError::NumericalError {
        message: format!("nelder_mead: reflect add - {}", err),
    })
}