use crate::error::{OptimizeError, OptimizeResult};
use scirs2_core::ndarray::{Array1, ArrayView1};
/// Outcome of a `robbins_monro` run.
#[derive(Debug, Clone)]
pub struct RobbinsMonroResult {
/// Final iterate.
pub x: Array1<f64>,
/// Euclidean norm of an evaluation of `M`.
///
/// When the run converged this is the norm of the evaluation that produced
/// the last update, i.e. it was taken at the iterate *before* that update.
/// When the iteration budget ran out it is a fresh evaluation at `x`.
pub residual: f64,
/// Number of iterations performed (`max_iter` when the run did not converge).
pub n_iter: usize,
/// `true` when the norm of an update step dropped below the tolerance.
pub converged: bool,
}
/// Tuning parameters for the Robbins–Monro iteration.
///
/// At iteration `k` (counted from 1) the solver uses the gain
/// `a / k^alpha`.
#[derive(Debug, Clone)]
pub struct RobbinsMonroOptions {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// The run stops once the Euclidean norm of an update step is below this.
    pub tol: f64,
    /// Decay exponent of the gain sequence.
    pub alpha: f64,
    /// Numerator of the gain sequence.
    pub a: f64,
}

impl Default for RobbinsMonroOptions {
    /// Defaults: 10 000 iterations, tolerance `1e-6`, gain `1 / k`.
    fn default() -> Self {
        RobbinsMonroOptions {
            max_iter: 10_000,
            tol: 1e-6,
            alpha: 1.0,
            a: 1.0,
        }
    }
}
/// Robbins–Monro iteration that drives `M(x)` towards zero.
///
/// Starting from `x0`, iteration `k = 1, 2, …, opts.max_iter` evaluates `m`
/// at the current iterate and applies `x ← x − a_k · M(x)` with
/// `a_k = opts.a / k^opts.alpha`. The run stops as soon as the Euclidean norm
/// of the applied update `a_k · M(x)` is below `opts.tol`.
///
/// # Returns
///
/// * On convergence: the updated iterate, `converged = true`, `n_iter = k`,
///   and `residual = ‖M‖` taken from the evaluation that produced the last
///   update (no additional call to `m` is made).
/// * Otherwise: the last iterate, `converged = false`,
///   `n_iter = opts.max_iter`, and `residual = ‖M(x)‖` from one extra call to
///   `m` at the final iterate.
///
/// # Errors
///
/// Returns `OptimizeError::ValueError` when `x0` is empty or when `m` returns
/// a vector whose length differs from `x0` inside the iteration loop.
pub fn robbins_monro<M>(
    m: &mut M,
    x0: &ArrayView1<f64>,
    opts: &RobbinsMonroOptions,
) -> OptimizeResult<RobbinsMonroResult>
where
    M: FnMut(&ArrayView1<f64>) -> Array1<f64>,
{
    let dim = x0.len();
    if dim == 0 {
        return Err(OptimizeError::ValueError(
            "x0 must be non-empty".to_string(),
        ));
    }

    // Euclidean norm, shared by both exit paths.
    let l2_norm = |v: &Array1<f64>| v.iter().map(|e| e * e).sum::<f64>().sqrt();

    let mut x = x0.to_owned();
    for k in 1..=opts.max_iter {
        let observed = m(&x.view());
        if observed.len() != dim {
            return Err(OptimizeError::ValueError(format!(
                "M returned length {} but x has length {}",
                observed.len(),
                dim
            )));
        }

        let gain = opts.a / (k as f64).powf(opts.alpha);

        // Apply the update and accumulate its squared length in one pass.
        let step_sq = x
            .iter_mut()
            .zip(observed.iter())
            .fold(0.0_f64, |acc, (xi, &mi)| {
                let step = gain * mi;
                *xi -= step;
                acc + step * step
            });

        if step_sq.sqrt() < opts.tol {
            // The reported residual reuses `observed`, which was evaluated
            // before the update above was applied.
            return Ok(RobbinsMonroResult {
                residual: l2_norm(&observed),
                x,
                n_iter: k,
                converged: true,
            });
        }
    }

    // Budget exhausted: report the residual at the iterate actually returned.
    let final_eval = m(&x.view());
    Ok(RobbinsMonroResult {
        residual: l2_norm(&final_eval),
        x,
        n_iter: opts.max_iter,
        converged: false,
    })
}
/// Outcome of a `kiefer_wolfowitz` run.
#[derive(Debug, Clone)]
pub struct KieferWolfowitzResult {
/// Final iterate.
pub x: Array1<f64>,
/// Loss evaluated once at `x` after the iteration stopped.
pub fun: f64,
/// Number of iterations performed (`max_iter` when the run did not converge).
pub n_iter: usize,
/// `true` when the norm of an update step dropped below the tolerance.
pub converged: bool,
}
/// Tuning parameters for the Kiefer–Wolfowitz iteration.
///
/// At iteration `k` (counted from 1) the solver uses the gain
/// `a / k^alpha` and the finite-difference half-width `c / k^gamma`.
#[derive(Debug, Clone)]
pub struct KieferWolfowitzOptions {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// The run stops once the Euclidean norm of an update step is below this.
    pub tol: f64,
    /// Decay exponent of the gain sequence.
    pub alpha: f64,
    /// Decay exponent of the finite-difference half-width.
    pub gamma: f64,
    /// Numerator of the gain sequence.
    pub a: f64,
    /// Numerator of the finite-difference half-width.
    pub c: f64,
}

impl Default for KieferWolfowitzOptions {
    /// Defaults: 10 000 iterations, tolerance `1e-6`, `alpha = 0.602`,
    /// `gamma = 0.101`, `a = c = 0.1`.
    fn default() -> Self {
        KieferWolfowitzOptions {
            max_iter: 10_000,
            tol: 1e-6,
            alpha: 0.602,
            gamma: 0.101,
            a: 0.1,
            c: 0.1,
        }
    }
}
/// Kiefer–Wolfowitz iteration: descent on a central finite-difference
/// gradient estimate of `loss`.
///
/// At iteration `k = 1, 2, …, opts.max_iter` the gain is
/// `a_k = opts.a / k^opts.alpha` and the probe half-width is
/// `c_k = opts.c / k^opts.gamma`. Each coordinate `i` is probed at
/// `x ± c_k·e_i` (forward probe first, then backward), giving
/// `g_i = (loss(x + c_k e_i) − loss(x − c_k e_i)) / (2 c_k)`, after which
/// `x ← x − a_k · g`. The run stops once `‖a_k · g‖₂ < opts.tol`.
///
/// `loss` is called `2 · n` times per iteration plus once at the end to fill
/// `fun`. `opts.c` must be non-zero, otherwise the quotient above is `0 / 0`.
///
/// # Returns
///
/// The final iterate, the loss at that iterate, the number of iterations
/// performed (`opts.max_iter` if the tolerance was never met) and the
/// convergence flag.
///
/// # Errors
///
/// Returns `OptimizeError::ValueError` when `x0` is empty.
pub fn kiefer_wolfowitz<L>(
    loss: &mut L,
    x0: &ArrayView1<f64>,
    opts: &KieferWolfowitzOptions,
) -> OptimizeResult<KieferWolfowitzResult>
where
    L: FnMut(&ArrayView1<f64>) -> f64,
{
    let n = x0.len();
    if n == 0 {
        return Err(OptimizeError::ValueError(
            "x0 must be non-empty".to_string(),
        ));
    }

    let mut x = x0.to_owned();
    // Single gradient buffer for the whole run; every entry is overwritten
    // before it is read in each iteration.
    let mut grad = Array1::<f64>::zeros(n);
    let mut n_iter = opts.max_iter;
    let mut converged = false;

    for k in 1..=opts.max_iter {
        let kf = k as f64;
        let ak = opts.a / kf.powf(opts.alpha);
        let ck = opts.c / kf.powf(opts.gamma);

        for i in 0..n {
            // Perturb coordinate `i` in place and restore the saved value
            // afterwards. This presents exactly the same points to `loss` as
            // probing two cloned vectors would, without copying the whole
            // iterate twice per coordinate.
            let xi = x[i];
            x[i] = xi + ck;
            let f_fwd = loss(&x.view());
            x[i] = xi - ck;
            let f_bwd = loss(&x.view());
            x[i] = xi;
            grad[i] = (f_fwd - f_bwd) / (2.0 * ck);
        }

        let mut step_sq = 0.0_f64;
        for (xi, &gi) in x.iter_mut().zip(grad.iter()) {
            let step = ak * gi;
            *xi -= step;
            step_sq += step * step;
        }

        if step_sq.sqrt() < opts.tol {
            converged = true;
            n_iter = k;
            break;
        }
    }

    let fun = loss(&x.view());
    Ok(KieferWolfowitzResult {
        x,
        fun,
        n_iter,
        converged,
    })
}
/// Tuning parameters for SPSA.
///
/// At iteration `k` the step uses the gain `a / (big_a + k)^alpha` and the
/// perturbation half-width `c / k^gamma`.
#[derive(Debug, Clone)]
pub struct SpsaOptions {
    /// Upper bound on the number of iterations.
    pub max_iter: usize,
    /// The run stops once the Euclidean norm of an update step is below this.
    pub tol: f64,
    /// Decay exponent of the gain sequence.
    pub alpha: f64,
    /// Decay exponent of the perturbation half-width.
    pub gamma: f64,
    /// Numerator of the gain sequence.
    pub a: f64,
    /// Offset added to the iteration index in the gain denominator.
    pub big_a: f64,
    /// Numerator of the perturbation half-width.
    pub c: f64,
}

impl Default for SpsaOptions {
    /// Defaults: 5 000 iterations, tolerance `1e-6`, `alpha = 0.602`,
    /// `gamma = 0.101`, `a = 0.1`, `big_a = 100`, `c = 0.1`.
    fn default() -> Self {
        SpsaOptions {
            max_iter: 5_000,
            tol: 1e-6,
            alpha: 0.602,
            gamma: 0.101,
            a: 0.1,
            big_a: 100.0,
            c: 0.1,
        }
    }
}
/// Outcome of a `spsa_minimize` run.
#[derive(Debug, Clone)]
pub struct SpsaResult {
/// Final iterate.
pub x: Array1<f64>,
/// Objective evaluated once at `x` after the iteration stopped.
pub fun: f64,
/// Number of iterations performed (`max_iter` when the run did not converge).
pub n_iter: usize,
/// `true` when the norm of an update step dropped below the tolerance.
pub converged: bool,
}
pub fn spsa_step<F>(
f: &mut F,
x: &mut Array1<f64>,
k: usize,
opts: &SpsaOptions,
rng_state: &mut u64,
) -> f64
where
F: FnMut(&ArrayView1<f64>) -> f64,
{
let n = x.len();
let ak = opts.a / (opts.big_a + k as f64).powf(opts.alpha);
let ck = opts.c / (k as f64).powf(opts.gamma);
let mut delta = Array1::<f64>::zeros(n);
for i in 0..n {
*rng_state = rng_state
.wrapping_mul(6364136223846793005)
.wrapping_add(1442695040888963407);
delta[i] = if (*rng_state >> 63) == 0 { 1.0 } else { -1.0 };
}
let x_fwd: Array1<f64> = x.iter().zip(delta.iter()).map(|(&xi, &di)| xi + ck * di).collect();
let x_bwd: Array1<f64> = x.iter().zip(delta.iter()).map(|(&xi, &di)| xi - ck * di).collect();
let f_fwd = f(&x_fwd.view());
let f_bwd = f(&x_bwd.view());
let diff = (f_fwd - f_bwd) / (2.0 * ck);
let mut step_sq = 0.0_f64;
for i in 0..n {
let gi = diff / delta[i]; let step = ak * gi;
x[i] -= step;
step_sq += step * step;
}
step_sq.sqrt()
}
/// Minimises `f` by repeated SPSA steps starting from `x0`.
///
/// Runs `spsa_step` for `k = 1, 2, …, opts.max_iter` and stops as soon as a
/// step's norm is below `opts.tol`. The sign generator always starts from the
/// same fixed seed, so two runs with identical inputs draw the same
/// perturbation directions. After the loop `f` is evaluated once more at the
/// final iterate to fill `fun`.
///
/// # Returns
///
/// The final iterate, the objective at that iterate, the number of steps
/// taken (`opts.max_iter` if the tolerance was never met) and the convergence
/// flag.
///
/// # Errors
///
/// Returns `OptimizeError::ValueError` when `x0` is empty.
pub fn spsa_minimize<F>(
    f: &mut F,
    x0: &ArrayView1<f64>,
    opts: &SpsaOptions,
) -> OptimizeResult<SpsaResult>
where
    F: FnMut(&ArrayView1<f64>) -> f64,
{
    if x0.is_empty() {
        return Err(OptimizeError::ValueError(
            "x0 must be non-empty".to_string(),
        ));
    }

    let mut x = x0.to_owned();
    let mut rng_state: u64 = 12345678901234567;

    // Assume the budget is exhausted until a sufficiently small step is seen.
    let mut n_iter = opts.max_iter;
    let mut converged = false;
    for k in 1..=opts.max_iter {
        if spsa_step(f, &mut x, k, opts, &mut rng_state) < opts.tol {
            converged = true;
            n_iter = k;
            break;
        }
    }

    let fun = f(&x.view());
    Ok(SpsaResult {
        x,
        fun,
        n_iter,
        converged,
    })
}
// Smoke tests: each solver is run on a small deterministic problem and the
// returned iterate is compared with the known solution using a loose bound.
#[cfg(test)]
mod tests {
use super::*;
use scirs2_core::ndarray::array;
// Robbins–Monro on M(x) = x - 2, which vanishes at x = 2.
#[test]
fn test_robbins_monro_linear() {
let mut m = |x: &ArrayView1<f64>| array![x[0] - 2.0];
let x0 = array![0.0];
let opts = RobbinsMonroOptions {
max_iter: 50_000,
tol: 1e-4,
a: 1.0,
alpha: 1.0,
};
let res = robbins_monro(&mut m, &x0.view(), &opts).expect("failed to create res");
assert!(
(res.x[0] - 2.0).abs() < 0.1,
"expected x* ≈ 2.0, got {}",
res.x[0]
);
}
// Kiefer–Wolfowitz on (x - 3)^2, which is minimised at x = 3.
#[test]
fn test_kiefer_wolfowitz_quadratic() {
let mut loss = |x: &ArrayView1<f64>| (x[0] - 3.0).powi(2);
let x0 = array![0.0];
let opts = KieferWolfowitzOptions {
max_iter: 20_000,
tol: 1e-5,
..Default::default()
};
let res = kiefer_wolfowitz(&mut loss, &x0.view(), &opts).expect("failed to create res");
assert!(
(res.x[0] - 3.0).abs() < 0.2,
"expected x* ≈ 3.0, got {}",
res.x[0]
);
}
// SPSA on a separable quadratic in two variables, minimised at (1, 2).
#[test]
fn test_spsa_quadratic() {
let mut f = |x: &ArrayView1<f64>| (x[0] - 1.0).powi(2) + (x[1] - 2.0).powi(2);
let x0 = array![0.0, 0.0];
let opts = SpsaOptions {
max_iter: 10_000,
tol: 1e-5,
a: 0.5,
big_a: 50.0,
c: 0.2,
..Default::default()
};
let res = spsa_minimize(&mut f, &x0.view(), &opts).expect("failed to create res");
assert!(
(res.x[0] - 1.0).abs() < 0.3,
"expected x[0] ≈ 1.0, got {}",
res.x[0]
);
assert!(
(res.x[1] - 2.0).abs() < 0.3,
"expected x[1] ≈ 2.0, got {}",
res.x[1]
);
}
}