// Step coefficients used by `nelder_mead` when it moves the worst simplex vertex.
/// Reflection coefficient: scales the step from the worst vertex through the centroid.
const NM_REFLECT: f64 = 1.0;
/// Expansion coefficient: scales the centroid-to-reflected step when the reflected point
/// beats the current best vertex.
const NM_EXPAND: f64 = 2.0;
/// Contraction coefficient: scales the step from the centroid toward either the reflected
/// point or the worst vertex.
const NM_CONTRACT: f64 = 0.5;
/// Shrink coefficient: factor by which every non-best vertex is pulled toward the best one.
const NM_SHRINK: f64 = 0.5;
/// Converts a `usize` index to `f64` without a direct integer-to-float `as` cast.
///
/// The index is widened to 64 bits and split into two 32-bit halves. Each half converts
/// to `f64` losslessly through `f64::from`, and the halves are recombined as
/// `upper * 2^32 + lower`, so only the final addition can round.
#[inline]
#[must_use]
pub fn index_to_f64(i: usize) -> f64 {
    const TWO_POW_32: f64 = 4_294_967_296.0;
    let wide = i as u64;
    // Both halves always fit in 32 bits; the fallback only satisfies the type checker.
    let upper = u32::try_from(wide >> 32).unwrap_or(u32::MAX);
    let lower = u32::try_from(wide & 0xFFFF_FFFF).unwrap_or(u32::MAX);
    TWO_POW_32 * f64::from(upper) + f64::from(lower)
}
/// Outcome of a `nelder_mead` run.
#[derive(Debug, Clone, PartialEq)]
pub struct NelderMeadResult {
/// Coordinates of the best simplex vertex found (empty when the start point was empty).
pub x: Vec<f64>,
/// Objective value at `x`.
pub fx: f64,
/// Number of simplex iterations started, including the one in which convergence was
/// detected; `0` for an empty start point.
pub iterations: usize,
/// `true` if the convergence test passed (or the start point was empty); `false` if the
/// loop ended because the iteration limit was reached.
pub converged: bool,
}
/// Minimises `objective` with the Nelder–Mead downhill-simplex method.
///
/// # Arguments
///
/// * `objective` - function to minimise; called with a slice of `start.len()` coordinates.
/// * `start` - initial guess. The initial simplex is `start` plus one extra vertex per
///   coordinate, offset along that coordinate by 5 % of its magnitude, or by `0.000_25`
///   when the coordinate's magnitude is at most `1e-8`.
/// * `tol` - convergence tolerance (see below).
/// * `max_iter` - maximum number of simplex iterations.
///
/// # Convergence
///
/// The run is reported as converged when the difference between the worst and best
/// objective values is at most `tol` **and** every vertex lies within
/// `max(sqrt(tol), tol)` (Euclidean distance) of the best vertex.
///
/// # Returns
///
/// The best vertex found, its objective value, the number of iterations started and the
/// convergence flag. If `max_iter` is exhausted first, the best vertex so far is returned
/// with `converged == false`. An empty `start` evaluates `objective(&[])` once and is
/// reported as converged after zero iterations.
///
/// Vertices whose objective value is NaN are ranked as the worst, so a finite vertex is
/// always preferred over a NaN one.
#[must_use]
// Kept as one routine: the reflect / expand / contract / shrink cascade shares its state.
#[allow(clippy::too_many_lines)]
pub fn nelder_mead<F>(objective: F, start: &[f64], tol: f64, max_iter: usize) -> NelderMeadResult
where
    F: Fn(&[f64]) -> f64,
{
    let n = start.len();
    if n == 0 {
        return NelderMeadResult {
            x: Vec::new(),
            fx: objective(&[]),
            iterations: 0,
            converged: true,
        };
    }

    // Build the initial simplex: the start point plus one perturbed vertex per axis.
    let mut simplex: Vec<Vec<f64>> = Vec::with_capacity(n + 1);
    simplex.push(start.to_vec());
    for j in 0..n {
        let mut v = start.to_vec();
        let step = if v[j].abs() > 1e-8 {
            0.05 * v[j].abs()
        } else {
            0.000_25
        };
        v[j] += step;
        simplex.push(v);
    }

    let mut fvals: Vec<f64> = simplex.iter().map(|v| objective(v)).collect();
    let mut order: Vec<usize> = (0..=n).collect();
    let mut iterations = 0;
    let mut converged = false;

    while iterations < max_iter {
        iterations += 1;
        nm_rank(&mut order, &fvals);
        let best = order[0];
        let worst = order[n];
        let second_worst = order[n - 1];

        // Convergence test: small spread in objective values and a small simplex.
        let spread = (fvals[worst] - fvals[best]).abs();
        let mut diameter = 0.0_f64;
        for v in &simplex {
            let mut dist = 0.0_f64;
            for j in 0..n {
                let d = v[j] - simplex[best][j];
                dist += d * d;
            }
            diameter = diameter.max(dist.sqrt());
        }
        if spread <= tol && diameter <= tol.sqrt().max(tol) {
            converged = true;
            break;
        }

        // Centroid of every vertex except the worst one.
        let mut centroid = vec![0.0; n];
        for &v in &order[..n] {
            for j in 0..n {
                centroid[j] += simplex[v][j];
            }
        }
        let inv_n = 1.0 / index_to_f64(n);
        for c in &mut centroid {
            *c *= inv_n;
        }

        // Reflect the worst vertex through the centroid.
        let reflected = axpy(&centroid, NM_REFLECT, &centroid, &simplex[worst]);
        let f_reflected = objective(&reflected);

        if f_reflected < fvals[best] {
            // New best point: try to go further in the same direction.
            let expanded = axpy(&centroid, NM_EXPAND, &reflected, &centroid);
            let f_expanded = objective(&expanded);
            if f_expanded < f_reflected {
                simplex[worst] = expanded;
                fvals[worst] = f_expanded;
            } else {
                simplex[worst] = reflected;
                fvals[worst] = f_reflected;
            }
        } else if f_reflected < fvals[second_worst] {
            simplex[worst] = reflected;
            fvals[worst] = f_reflected;
        } else {
            // Contract toward the reflected point if it beats the worst vertex (outside
            // contraction), otherwise toward the worst vertex itself (inside contraction).
            let (contracted, f_contracted) = if f_reflected < fvals[worst] {
                let c = axpy(&centroid, NM_CONTRACT, &reflected, &centroid);
                let fc = objective(&c);
                (c, fc)
            } else {
                let c = axpy(&centroid, NM_CONTRACT, &simplex[worst], &centroid);
                let fc = objective(&c);
                (c, fc)
            };
            if f_contracted < f_reflected.min(fvals[worst]) {
                simplex[worst] = contracted;
                fvals[worst] = f_contracted;
            } else {
                // Contraction failed: shrink every non-best vertex toward the best one.
                let best_pt = simplex[best].clone();
                for &v in &order[1..] {
                    for j in 0..n {
                        simplex[v][j] = best_pt[j] + NM_SHRINK * (simplex[v][j] - best_pt[j]);
                    }
                    fvals[v] = objective(&simplex[v]);
                }
            }
        }
    }

    // The last iteration may have changed a vertex after ranking, so rank once more.
    nm_rank(&mut order, &fvals);
    let best = order[0];
    NelderMeadResult {
        x: simplex[best].clone(),
        fx: fvals[best],
        iterations,
        converged,
    }
}

/// Orders objective values numerically, placing NaN after every number.
///
/// Unlike `partial_cmp(..).unwrap_or(Equal)` this is a consistent total preorder, so it is
/// safe to hand to `sort_by` and never lets a NaN vertex rank as the best one.
fn nm_cmp_nan_last(lhs: f64, rhs: f64) -> core::cmp::Ordering {
    lhs.partial_cmp(&rhs)
        .unwrap_or_else(|| lhs.is_nan().cmp(&rhs.is_nan()))
}

/// Stable-sorts the vertex indices in `order` from best (lowest value) to worst.
fn nm_rank(order: &mut [usize], fvals: &[f64]) {
    order.sort_by(|&a, &b| nm_cmp_nan_last(fvals[a], fvals[b]));
}
/// Returns the point `centroid + coeff * (point - reference)`, computed per coordinate.
///
/// The result is as long as the shortest of the three input slices.
fn axpy(centroid: &[f64], coeff: f64, point: &[f64], reference: &[f64]) -> Vec<f64> {
    let len = centroid.len().min(point.len()).min(reference.len());
    let mut moved = Vec::with_capacity(len);
    for (base, (target, origin)) in centroid.iter().zip(point.iter().zip(reference.iter())) {
        moved.push(base + coeff * (target - origin));
    }
    moved
}
/// Finds a root of `f` between `a` and `b` using Brent's method (inverse quadratic
/// interpolation and secant steps, with a bisection fallback).
///
/// # Arguments
///
/// * `f` - function whose root is sought.
/// * `a`, `b` - bracket endpoints; `f(a)` and `f(b)` must have opposite signs.
/// * `tol` - the search stops once the bracket width `|b - a|` is at most `tol`.
/// * `max_iter` - maximum number of iterations.
///
/// # Returns
///
/// * `Some(endpoint)` immediately if `f` is exactly zero at `a` or `b`.
/// * `None` if `f(a)` and `f(b)` are both strictly positive or both strictly negative.
/// * Otherwise `Some(x)`, where `x` is the current best estimate when the bracket became
///   narrower than `tol`, when `f(x) == 0.0`, or when `max_iter` ran out. In the last case
///   the estimate is not guaranteed to meet the tolerance.
#[must_use]
#[allow(clippy::many_single_char_names)]
pub fn brent_root<F>(f: F, a: f64, b: f64, tol: f64, max_iter: usize) -> Option<f64>
where
    F: Fn(f64) -> f64,
{
    // Signs are compared directly rather than via the sign of a product: `x * y`
    // underflows to zero for tiny function values and would hide the sign information.
    let same_sign = |x: f64, y: f64| (x > 0.0 && y > 0.0) || (x < 0.0 && y < 0.0);
    let opposite_sign = |x: f64, y: f64| (x > 0.0 && y < 0.0) || (x < 0.0 && y > 0.0);

    let mut a = a;
    let mut b = b;
    let mut fa = f(a);
    let mut fb = f(b);
    if fa == 0.0 {
        return Some(a);
    }
    if fb == 0.0 {
        return Some(b);
    }
    if same_sign(fa, fb) {
        return None;
    }
    // Keep `b` as the endpoint with the smaller residual, i.e. the best estimate.
    if fa.abs() < fb.abs() {
        core::mem::swap(&mut a, &mut b);
        core::mem::swap(&mut fa, &mut fb);
    }
    // `c` is the previous value of `b`, `d` the one before that.
    let mut c = a;
    let mut fc = fa;
    let mut d = a;
    // `mflag` records whether the previous step was a bisection.
    let mut mflag = true;

    for _ in 0..max_iter {
        if (b - a).abs() <= tol || fb == 0.0 {
            return Some(b);
        }

        let mut s = if (fa - fc).abs() > f64::EPSILON && (fb - fc).abs() > f64::EPSILON {
            // Inverse quadratic interpolation through (a, fa), (b, fb), (c, fc).
            a * fb * fc / ((fa - fb) * (fa - fc))
                + b * fa * fc / ((fb - fa) * (fb - fc))
                + c * fa * fb / ((fc - fa) * (fc - fb))
        } else {
            // Secant step.
            b - fb * (b - a) / (fb - fa)
        };

        // Reject the interpolated point if it leaves the interval between (3a + b) / 4
        // and b, or if the steps are not shrinking fast enough; bisect instead.
        let lo = (3.0 * a + b) / 4.0;
        let bound_lo = lo.min(b);
        let bound_hi = lo.max(b);
        let use_bisection = !(bound_lo..=bound_hi).contains(&s)
            || (mflag && (s - b).abs() >= (b - c).abs() / 2.0)
            || (!mflag && (s - b).abs() >= (c - d).abs() / 2.0)
            || (mflag && (b - c).abs() < tol)
            || (!mflag && (c - d).abs() < tol);
        if use_bisection {
            s = f64::midpoint(a, b);
            mflag = true;
        } else {
            mflag = false;
        }

        let fs = f(s);
        d = c;
        c = b;
        fc = fb;
        // Replace whichever endpoint keeps the sign change inside the bracket.
        if opposite_sign(fa, fs) {
            b = s;
            fb = fs;
        } else {
            a = s;
            fa = fs;
        }
        if fa.abs() < fb.abs() {
            core::mem::swap(&mut a, &mut b);
            core::mem::swap(&mut fa, &mut fb);
        }
    }
    Some(b)
}
/// Solves a 3×3 symmetric system whose matrix is given in packed form.
///
/// `a` holds the upper triangle row by row: `[a00, a01, a02, a11, a12, a22]`. The matrix
/// is expanded to a dense row-major 3×3 array and handed to `solve_spd`; `None` is
/// returned whenever `solve_spd` returns `None`.
#[must_use]
pub fn solve_spd_3(a: &[f64; 6], rhs: &[f64; 3]) -> Option<[f64; 3]> {
    let [a00, a01, a02, a11, a12, a22] = *a;
    let dense = [a00, a01, a02, a01, a11, a12, a02, a12, a22];
    solve_spd(&dense, rhs, 3).map(|sol| [sol[0], sol[1], sol[2]])
}
/// Solves `A x = rhs` for a symmetric positive-definite matrix via Cholesky factorisation.
///
/// `a` is the `n × n` matrix in row-major order; only its lower triangle (including the
/// diagonal) is read. The matrix is factored as `L Lᵀ`, then the system is solved by
/// forward substitution (`L y = rhs`) followed by back substitution (`Lᵀ x = y`).
///
/// Returns `None` when the slice lengths do not match `n`, when a pivot is non-positive
/// or non-finite (the matrix is not positive definite), or when the solution contains a
/// non-finite value.
#[must_use]
#[allow(clippy::many_single_char_names)]
pub fn solve_spd(a: &[f64], rhs: &[f64], n: usize) -> Option<Vec<f64>> {
    if a.len() != n * n || rhs.len() != n {
        return None;
    }

    // Cholesky factor, stored row-major; entries above the diagonal stay zero.
    let mut chol = vec![0.0; n * n];
    for row in 0..n {
        for col in 0..=row {
            let reduced = (0..col).fold(a[row * n + col], |acc, k| {
                acc - chol[row * n + k] * chol[col * n + k]
            });
            let entry = if row == col {
                // A diagonal pivot must be strictly positive and finite.
                if !(reduced > 0.0 && reduced.is_finite()) {
                    return None;
                }
                reduced.sqrt()
            } else {
                let pivot = chol[col * n + col];
                if pivot == 0.0 {
                    return None;
                }
                reduced / pivot
            };
            chol[row * n + col] = entry;
        }
    }

    // Forward substitution: L y = rhs.
    let mut forward = vec![0.0; n];
    for row in 0..n {
        let reduced = (0..row).fold(rhs[row], |acc, k| acc - chol[row * n + k] * forward[k]);
        forward[row] = reduced / chol[row * n + row];
    }

    // Back substitution: Lᵀ x = y.
    let mut solution = vec![0.0; n];
    for row in (0..n).rev() {
        let reduced =
            ((row + 1)..n).fold(forward[row], |acc, k| acc - chol[k * n + row] * solution[k]);
        solution[row] = reduced / chol[row * n + row];
    }

    solution.iter().all(|v| v.is_finite()).then_some(solution)
}
/// Outcome of a `levenberg_marquardt` run.
#[derive(Debug, Clone, PartialEq)]
pub struct LevenbergMarquardtResult {
/// Parameter vector after the last accepted step (the start vector if no step was accepted).
pub params: Vec<f64>,
/// Weighted sum of squared residuals, `Σ w · r²`, evaluated at `params`.
pub cost: f64,
/// Number of outer iterations started.
pub iterations: usize,
/// `true` if one of the stopping tests fired; `false` if the iteration limit was reached first.
pub converged: bool,
}
/// Minimises a weighted sum of squared residuals with the Levenberg–Marquardt method.
///
/// # Arguments
///
/// * `residual` - called with a parameter vector; returns one `(r, w, jac)` tuple per
///   observation, where `r` is the residual, `w` its weight and `jac` the derivatives of
///   `r` with respect to each parameter. Every `jac` must hold at least `start.len()`
///   entries, otherwise indexing panics.
/// * `start` - initial parameter vector.
/// * `tol` - tolerance shared by all stopping tests.
/// * `max_iter` - maximum number of outer iterations.
///
/// The minimised cost is `Σ w · r²`. Each outer iteration builds the normal equations
/// `JᵀWJ` and `-JᵀWr`, adds `mu` times the (floored) diagonal as damping, and solves for a
/// step with `solve_spd`. A step is accepted when it lowers the cost; otherwise `mu` is
/// doubled and the solve is retried, up to 30 times.
///
/// # Returns
///
/// The final parameters, their cost, the iteration count and a convergence flag. The flag
/// is set when the gradient norm is at most `tol`, the cost is at most `tol²`, an accepted
/// step is no longer than `tol`, or no improving step could be found after 30 damping
/// increases (the search has stalled).
#[must_use]
pub fn levenberg_marquardt<F>(
    residual: F,
    start: &[f64],
    tol: f64,
    max_iter: usize,
) -> LevenbergMarquardtResult
where
    F: Fn(&[f64]) -> Vec<(f64, f64, Vec<f64>)>,
{
    // Damping increases tried per outer iteration before the search is declared stalled.
    const MAX_DAMPING_TRIES: usize = 30;
    // Floor for the diagonal scaling so a zero diagonal entry still receives damping.
    const MIN_DIAG_SCALE: f64 = 1e-12;

    let n = start.len();
    let mut params = start.to_vec();
    let mut mu = 1e-3;
    let nu_growth = 2.0;
    let cost_of =
        |obs: &[(f64, f64, Vec<f64>)]| -> f64 { obs.iter().map(|(r, w, _)| w * r * r).sum() };

    let mut obs = residual(&params);
    let mut cost = cost_of(&obs);
    let mut iterations = 0;
    let mut converged = false;

    while iterations < max_iter {
        iterations += 1;

        // Normal equations: jtj = JᵀWJ and jtr = -JᵀWr.
        let mut jtj = vec![0.0; n * n];
        let mut jtr = vec![0.0; n];
        for (r, w, jac) in &obs {
            for i in 0..n {
                jtr[i] -= w * jac[i] * r;
                for j in 0..n {
                    jtj[i * n + j] += w * jac[i] * jac[j];
                }
            }
        }

        let grad_norm = jtr.iter().map(|g| g * g).sum::<f64>().sqrt();
        if grad_norm <= tol || cost <= tol * tol {
            converged = true;
            break;
        }

        // The diagonal scaling does not change while `mu` is being adjusted.
        let scale: Vec<f64> = (0..n)
            .map(|i| jtj[i * n + i].max(MIN_DIAG_SCALE))
            .collect();

        let mut accepted = false;
        for _ in 0..MAX_DAMPING_TRIES {
            let mut damped = jtj.clone();
            for i in 0..n {
                damped[i * n + i] += mu * scale[i];
            }
            let Some(delta) = solve_spd(&damped, &jtr, n) else {
                mu *= nu_growth;
                continue;
            };

            let trial: Vec<f64> = params
                .iter()
                .zip(delta.iter())
                .map(|(&p, &d)| p + d)
                .collect();
            let trial_obs = residual(&trial);
            let trial_cost = cost_of(&trial_obs);

            // Reduction predicted by the linearised model for cost = Σ w·r²:
            // δᵀ(μ·D·δ + jtr). It is not halved, because the cost carries no ½ factor.
            let predicted: f64 = (0..n)
                .map(|i| delta[i] * (mu * scale[i] * delta[i] + jtr[i]))
                .sum::<f64>();
            let actual = cost - trial_cost;
            let gain = if predicted > 0.0 {
                actual / predicted
            } else {
                -1.0
            };

            if gain > 0.0 && trial_cost < cost {
                let step_norm = delta.iter().map(|d| d * d).sum::<f64>().sqrt();
                params = trial;
                obs = trial_obs;
                cost = trial_cost;
                // Relax the damping more the better the model predicted the reduction.
                let shrink = (1.0 - (2.0 * gain - 1.0).powi(3)).max(1.0 / 3.0);
                mu *= shrink;
                accepted = true;
                if step_norm <= tol {
                    converged = true;
                }
                break;
            }
            mu *= nu_growth;
        }

        if !accepted {
            // No damping level produced an improvement: treat the search as stalled.
            converged = true;
            break;
        }
        if converged {
            break;
        }
    }

    LevenbergMarquardtResult {
        params,
        cost,
        iterations,
        converged,
    }
}
// Unit tests for the optimisation and linear-algebra helpers defined in this file.
#[cfg(test)]
mod tests {
use super::*;
// Rosenbrock function (1 - x0)^2 + 100 (x1 - x0^2)^2; it is zero at (1, 1).
fn rosenbrock(x: &[f64]) -> f64 {
let a = 1.0 - x[0];
let b = x[1] - x[0] * x[0];
a * a + 100.0 * b * b
}
// A 2-D sphere function must be driven to the origin and reported as converged.
#[test]
fn nelder_mead_minimises_sphere() {
let res = nelder_mead(|x| x[0] * x[0] + x[1] * x[1], &[3.0, -2.0], 1e-14, 1000);
assert!(res.fx < 1e-10, "fx = {}", res.fx);
assert!(res.x[0].abs() < 1e-5);
assert!(res.x[1].abs() < 1e-5);
assert!(res.converged);
}
// Rosenbrock from the start point (-1.2, 1) must end up near (1, 1).
#[test]
fn nelder_mead_minimises_rosenbrock() {
let res = nelder_mead(rosenbrock, &[-1.2, 1.0], 1e-14, 5000);
assert!((res.x[0] - 1.0).abs() < 1e-3, "x0 = {}", res.x[0]);
assert!((res.x[1] - 1.0).abs() < 1e-3, "x1 = {}", res.x[1]);
assert!(res.fx < 1e-6);
}
// One-dimensional input (a two-vertex simplex), starting from a zero coordinate.
#[test]
fn nelder_mead_handles_one_dimension() {
let res = nelder_mead(|x| (x[0] - 4.0) * (x[0] - 4.0), &[0.0], 1e-14, 1000);
assert!((res.x[0] - 4.0).abs() < 1e-5);
}
// An empty start point evaluates the objective once and reports convergence.
#[test]
fn nelder_mead_empty_start() {
let res = nelder_mead(|_| 7.0, &[], 1e-12, 10);
assert!((res.fx - 7.0).abs() < 1e-15);
assert!(res.converged);
}
// Root of x^2 - 2 on [0, 2] is sqrt(2).
#[test]
fn brent_finds_sqrt_two() {
let root = brent_root(|x| x * x - 2.0, 0.0, 2.0, 1e-14, 200).unwrap();
assert!((root - 2.0_f64.sqrt()).abs() < 1e-12);
}
// Root of x^3 - x - 2 on [1, 2], compared against a precomputed reference value.
#[test]
fn brent_finds_cubic_root() {
let root = brent_root(|x| x * x * x - x - 2.0, 1.0, 2.0, 1e-14, 200).unwrap();
assert!((root - 1.521_379_706_804_567_6).abs() < 1e-10);
}
// A root located exactly on the left endpoint is returned without iterating.
#[test]
fn brent_endpoint_root() {
let root = brent_root(|x| x - 3.0, 3.0, 5.0, 1e-12, 50).unwrap();
assert!((root - 3.0).abs() < 1e-15);
}
// x^2 + 1 is positive at both endpoints, so there is no bracket and the result is `None`.
#[test]
fn brent_rejects_no_bracket() {
assert!(brent_root(|x| x * x + 1.0, -1.0, 1.0, 1e-12, 50).is_none());
}
// Root of cos(x) - x on [0, 1], compared against a precomputed reference value.
#[test]
fn brent_finds_transcendental_root() {
let root = brent_root(|x| x.cos() - x, 0.0, 1.0, 1e-14, 200).unwrap();
assert!((root - 0.739_085_133_215_160_6).abs() < 1e-10);
}
// Packed identity matrix: the solution equals the right-hand side.
#[test]
fn solve_spd_3_identity() {
let x = solve_spd_3(&[1.0, 0.0, 0.0, 1.0, 0.0, 1.0], &[2.0, 3.0, 5.0]).unwrap();
assert!((x[0] - 2.0).abs() < 1e-15);
assert!((x[1] - 3.0).abs() < 1e-15);
assert!((x[2] - 5.0).abs() < 1e-15);
}
// The right-hand side is built as A * (1, 2, 3), so the solver must recover (1, 2, 3).
#[test]
fn solve_spd_3_known_system() {
let a = [4.0, 1.0, 1.0, 3.0, 0.0, 2.0];
let rhs = [4.0 + 2.0 + 3.0, 1.0 + 6.0 + 0.0, 1.0 + 0.0 + 6.0];
let x = solve_spd_3(&a, &rhs).unwrap();
assert!((x[0] - 1.0).abs() < 1e-12);
assert!((x[1] - 2.0).abs() < 1e-12);
assert!((x[2] - 3.0).abs() < 1e-12);
}
// [[1, 2], [2, 1]] is symmetric but not positive definite, so the factorisation must fail.
#[test]
fn solve_spd_rejects_non_spd() {
assert!(solve_spd(&[1.0, 2.0, 2.0, 1.0], &[1.0, 1.0], 2).is_none());
}
// A right-hand side whose length differs from `n` is rejected.
#[test]
fn solve_spd_rejects_bad_dims() {
assert!(solve_spd(&[1.0, 0.0, 0.0, 1.0], &[1.0], 2).is_none());
}
// Fitting a constant to 1, 2, 3, 4 by least squares yields their mean, 2.5.
#[test]
fn levenberg_marquardt_fits_mean() {
let data = [1.0_f64, 2.0, 3.0, 4.0];
let res = levenberg_marquardt(
|p: &[f64]| {
data.iter()
.map(|&y| (p[0] - y, 1.0, vec![1.0]))
.collect::<Vec<_>>()
},
&[0.0],
1e-14,
200,
);
assert!((res.params[0] - 2.5).abs() < 1e-8);
assert!(res.converged);
}
// Noise-free samples of y = 2 + 3x must recover intercept 2 and slope 3.
#[test]
fn levenberg_marquardt_fits_line() {
let xs = [0.0_f64, 1.0, 2.0, 3.0, 4.0];
let res = levenberg_marquardt(
|p: &[f64]| {
xs.iter()
.map(|&x| {
let y = 2.0 + 3.0 * x;
(p[0] + p[1] * x - y, 1.0, vec![1.0, x])
})
.collect::<Vec<_>>()
},
&[0.0, 0.0],
1e-14,
300,
);
assert!((res.params[0] - 2.0).abs() < 1e-6, "a = {}", res.params[0]);
assert!((res.params[1] - 3.0).abs() < 1e-6, "b = {}", res.params[1]);
}
// Nonlinear model exp(a * x) fitted to samples of exp(0.5 * x); expects a close to 0.5.
#[test]
fn levenberg_marquardt_nonlinear_exponential() {
let xs = [0.0_f64, 0.5, 1.0, 1.5, 2.0];
let res = levenberg_marquardt(
|p: &[f64]| {
xs.iter()
.map(|&x| {
let model = (p[0] * x).exp();
let y = (0.5 * x).exp();
(model - y, 1.0, vec![x * model])
})
.collect::<Vec<_>>()
},
&[0.1],
1e-14,
300,
);
assert!((res.params[0] - 0.5).abs() < 1e-5, "a = {}", res.params[0]);
}
}