use super::lbfgsb::hv_product;
use super::types::{OptResult, SlbfgsConfig};
use crate::error::OptimizeError;
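/// Minimal 32-bit linear congruential generator (Numerical Recipes constants),
/// used for reproducible mini-batch sampling without an external RNG crate.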
pub struct Lcg {
state: u64,
}
impl Lcg {
pub fn new(seed: u64) -> Self {
Self { state: seed }
}
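    /// Advances the generator and returns the next 32-bit value. The state is
    /// masked to 32 bits so the recurrence runs modulo 2^32, as these
    /// constants assume.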
pub fn next_u32(&mut self) -> u32 {
self.state = self
.state
.wrapping_mul(1_664_525)
.wrapping_add(1_013_904_223)
& 0xFFFF_FFFF;
self.state as u32
}
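    /// Pseudo-random index in `0..n` via a simple modulo (slight bias when
    /// `n` is not a power of two, acceptable for batch sampling).
    /// Panics if `n == 0`.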
pub fn next_usize(&mut self, n: usize) -> usize {
(self.next_u32() as usize) % n
}
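    /// Fills `buf` with `min(k, n)` distinct indices from `0..n` using a
    /// partial Fisher-Yates shuffle over a freshly allocated pool.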
pub fn sample_without_replacement(&mut self, n: usize, k: usize, buf: &mut Vec<usize>) {
buf.clear();
if k == 0 || n == 0 {
return;
}
let k = k.min(n);
let mut pool: Vec<usize> = (0..n).collect();
for i in 0..k {
let j = i + self.next_usize(n - i);
pool.swap(i, j);
buf.push(pool[i]);
}
}
}
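/// SVRG-corrected stochastic gradient: g_B(x_k) - g_B(x_snap) + g_full(x_snap),
/// where both batch gradients use the same mini-batch B, so the estimate is
/// unbiased and its variance shrinks near the snapshot point.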
fn svrg_gradient(
stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
x_k: &[f64],
x_snap: &[f64],
    g_snap: &[f64],
    batch: &[usize],
) -> Vec<f64> {
let (_, g_k) = stoch_f_and_g(x_k, batch);
let (_, g_s) = stoch_f_and_g(x_snap, batch);
g_k.iter()
.zip(g_s.iter())
.zip(g_snap.iter())
.map(|((gki, gsi), gfi)| gki - gsi + gfi)
.collect()
}
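/// Curvature difference y = g_B(x_new) - g_B(x_old) on a fixed mini-batch, so
/// both gradients see identical samples and the secant pair (s, y) is
/// consistent. Returns zeros for an empty batch (the pair is then skipped).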
fn curvature_y(
stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
x_new: &[f64],
x_old: &[f64],
batch: &[usize],
) -> Vec<f64> {
let n = x_new.len();
if batch.is_empty() {
return vec![0.0; n];
}
let (_, g_new) = stoch_f_and_g(x_new, batch);
let (_, g_old) = stoch_f_and_g(x_old, batch);
g_new
.iter()
.zip(g_old.iter())
.map(|(gn, go)| gn - go)
.collect()
}
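/// Stochastic L-BFGS optimizer with optional SVRG variance reduction.
/// Maintains a bounded history of `(s, y)` curvature pairs and applies the
/// two-loop recursion (`hv_product`) to precondition the stochastic gradient.
///
/// A minimal usage sketch, assuming caller-supplied `stoch_f_and_g` and
/// `full_f_and_g` closures plus data sizes (marked `ignore`; not a doc-test):
///
/// ```ignore
/// let opt = SlbfgsOptimizer::default_config();
/// let result = opt.minimize(&stoch_f_and_g, &full_f_and_g, n_samples, &x0)?;
/// println!("f = {:.6}, |g| = {:.3e}", result.f_val, result.grad_norm);
/// ```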
pub struct SlbfgsOptimizer {
pub config: SlbfgsConfig,
}
impl SlbfgsOptimizer {
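    /// Creates an optimizer with the given configuration.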
pub fn new(config: SlbfgsConfig) -> Self {
Self { config }
}
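    /// Convenience constructor using `SlbfgsConfig::default()`.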
pub fn default_config() -> Self {
Self {
config: SlbfgsConfig::default(),
}
}
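    /// Runs the stochastic L-BFGS loop from `x0`.
    ///
    /// `stoch_f_and_g` evaluates (f, grad) on a mini-batch of sample indices;
    /// `full_grad_fn` evaluates them on the full data set and is used for
    /// snapshots, the convergence test, and best-iterate tracking.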
pub fn minimize(
&self,
stoch_f_and_g: &dyn Fn(&[f64], &[usize]) -> (f64, Vec<f64>),
full_grad_fn: &dyn Fn(&[f64]) -> (f64, Vec<f64>),
n_samples: usize,
x0: &[f64],
) -> Result<OptResult, OptimizeError> {
let n = x0.len();
let cfg = &self.config;
let m = cfg.m;
if n_samples == 0 {
return Err(OptimizeError::ValueError(
"n_samples must be positive".to_string(),
));
}
let mut x = x0.to_vec();
let mut rng = Lcg::new(cfg.seed);
let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
let mut rho_hist: Vec<f64> = Vec::with_capacity(m);
let mut gamma = 1.0_f64;
let mut x_snap = x.clone();
let (mut f_snap, mut g_snap) = full_grad_fn(&x_snap);
let mut n_iter = 0usize;
let mut converged = false;
let mut batch_buf: Vec<usize> = Vec::with_capacity(cfg.batch_size);
let mut curv_batch_buf: Vec<usize> = Vec::with_capacity(cfg.curvature_batch_size);
let mut best_x = x.clone();
let mut best_f = f_snap;
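        // Main loop: snapshot, convergence check, preconditioned stochastic
        // step, then curvature-history update.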
for iter in 0..cfg.max_iter {
n_iter = iter;
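            // Refresh the full-gradient snapshot on schedule. Refreshing even
            // when variance reduction is off keeps the convergence check below
            // from testing a stale gradient; `.max(1)` guards a zero frequency.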
            if iter % cfg.snapshot_freq.max(1) == 0 {
let (fs, gs) = full_grad_fn(&x);
x_snap = x.clone();
f_snap = fs;
g_snap = gs;
}
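            // Converged once the latest full (snapshot) gradient is small.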
let gn = g_snap.iter().map(|g| g * g).sum::<f64>().sqrt();
if gn < cfg.tol {
converged = true;
break;
}
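            // Draw a fresh mini-batch, then form the gradient estimate
            // (SVRG-corrected if variance reduction is enabled).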
rng.sample_without_replacement(n_samples, cfg.batch_size, &mut batch_buf);
let g_k = if cfg.variance_reduction {
svrg_gradient(stoch_f_and_g, &x, &x_snap, &g_snap, &batch_buf)
} else {
let (_, gk) = stoch_f_and_g(&x, &batch_buf);
gk
};
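            // Two-loop recursion: apply the implicit inverse-Hessian
            // approximation to precondition the gradient estimate.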
let hg = hv_product(&g_k, &s_hist, &y_hist, &rho_hist, gamma);
let d: Vec<f64> = hg.iter().map(|v| -v).collect();
let slope: f64 = g_k.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
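            // If the quasi-Newton direction is not a descent direction,
            // fall back to steepest descent for this step.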
let d = if slope >= 0.0 {
g_k.iter().map(|gi| -gi).collect::<Vec<f64>>()
} else {
d
};
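            // Take a fixed-length step along the search direction.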
let x_new: Vec<f64> = x
.iter()
.zip(d.iter())
.map(|(xi, di)| xi + cfg.lr * di)
.collect();
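            // Draw an independent (typically larger) batch for the curvature
            // pair so the noise in y_k is decoupled from the step just taken.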
rng.sample_without_replacement(
n_samples,
cfg.curvature_batch_size,
&mut curv_batch_buf,
);
let s_k: Vec<f64> = (0..n).map(|i| x_new[i] - x[i]).collect();
let y_k = curvature_y(stoch_f_and_g, &x_new, &x, &curv_batch_buf);
let sy: f64 = s_k.iter().zip(y_k.iter()).map(|(si, yi)| si * yi).sum();
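            // Accept the curvature pair only if s^T y is sufficiently positive
            // (tolerance scaled by ||s||); otherwise skip the update so the
            // implicit Hessian approximation stays positive definite.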
if sy > 1e-14 * s_k.iter().map(|si| si * si).sum::<f64>().sqrt() {
if s_hist.len() == m {
s_hist.remove(0);
y_hist.remove(0);
rho_hist.remove(0);
}
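                // Barzilai-Borwein scaling for the initial Hessian:
                // gamma = (s^T y) / (y^T y).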
let yy: f64 = y_k.iter().map(|yi| yi * yi).sum();
if yy > 1e-14 {
gamma = sy / yy;
}
rho_hist.push(1.0 / sy);
s_hist.push(s_k);
y_hist.push(y_k);
}
x = x_new;
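            // Stochastic steps are not monotone in f, so track the best
            // iterate seen under the full objective.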
let (f_curr, _) = full_grad_fn(&x);
if f_curr < best_f {
best_f = f_curr;
best_x = x.clone();
}
}
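        // Report the best iterate and its full-data gradient norm.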
let (_, g_final) = full_grad_fn(&best_x);
let grad_norm = g_final.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
Ok(OptResult {
x: best_x,
f_val: best_f,
grad_norm,
n_iter,
converged,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::second_order::types::SlbfgsConfig;
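    // Mini-batch quadratic test objective: mean of (x_{i mod n} - 1)^2 over
    // the batch; an empty batch means "use all coordinates".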
fn stoch_quad(x: &[f64], batch: &[usize]) -> (f64, Vec<f64>) {
let n = x.len();
if batch.is_empty() {
let f: f64 = x.iter().map(|xi| (xi - 1.0).powi(2)).sum::<f64>() / n as f64;
let g: Vec<f64> = x.iter().map(|xi| 2.0 * (xi - 1.0) / n as f64).collect();
return (f, g);
}
let bs = batch.len() as f64;
let f: f64 = batch.iter().map(|&i| (x[i % n] - 1.0).powi(2)).sum::<f64>() / bs;
let mut g = vec![0.0_f64; n];
for &idx in batch {
g[idx % n] += 2.0 * (x[idx % n] - 1.0) / bs;
}
(f, g)
}
fn full_quad(x: &[f64]) -> (f64, Vec<f64>) {
let n = x.len();
let all: Vec<usize> = (0..n).collect();
stoch_quad(x, &all)
}
#[test]
fn test_slbfgs_gradient_variance_reduction() {
let x_star = vec![1.0; 4];
let x_snap = vec![1.0; 4];
let (_, g_snap) = full_quad(&x_snap);
let batch = vec![0, 1, 2, 3];
let g_corr = svrg_gradient(&stoch_quad, &x_star, &x_snap, &g_snap, &batch);
for gi in &g_corr {
assert!(
gi.abs() < 1e-12,
"Corrected gradient should be zero at optimum: got {}",
gi
);
}
}
#[test]
fn test_slbfgs_curvature_condition() {
let x_old = vec![2.0, 3.0];
let x_new = vec![1.5, 2.5];
let all_batch: Vec<usize> = (0..2).collect();
let y = curvature_y(&stoch_quad, &x_new, &x_old, &all_batch);
let s: Vec<f64> = x_new
.iter()
.zip(x_old.iter())
.map(|(xn, xo)| xn - xo)
.collect();
let sy: f64 = s.iter().zip(y.iter()).map(|(si, yi)| si * yi).sum();
assert!(sy > 0.0, "Curvature condition y^T s > 0 violated: {}", sy);
}
#[test]
fn test_slbfgs_stochastic_convergence() {
        let cfg = SlbfgsConfig {
            max_iter: 300,
            lr: 0.05,
            batch_size: 4,
            curvature_batch_size: 8,
            variance_reduction: true,
            tol: 1e-4,
            ..SlbfgsConfig::default()
        };
let opt = SlbfgsOptimizer::new(cfg);
let x0 = vec![0.0_f64; 4];
let result = opt
.minimize(&stoch_quad, &full_quad, 4, &x0)
.expect("S-L-BFGS failed");
for xi in &result.x {
assert!(
(xi - 1.0).abs() < 0.2,
"S-L-BFGS did not converge: x={:?}",
result.x
);
}
}
#[test]
fn test_second_order_config_default() {
use crate::second_order::types::{LbfgsBConfig, SlbfgsConfig, Sr1Config};
let _c1 = LbfgsBConfig::default();
let _c2 = Sr1Config::default();
let _c3 = SlbfgsConfig::default();
}
#[test]
fn test_slbfgs_batch_selection() {
let mut rng = Lcg::new(12345);
let mut buf = Vec::new();
rng.sample_without_replacement(100, 10, &mut buf);
assert_eq!(buf.len(), 10);
let mut sorted = buf.clone();
sorted.sort_unstable();
sorted.dedup();
assert_eq!(sorted.len(), 10, "Duplicate indices in batch selection");
for &idx in &buf {
assert!(idx < 100, "Index out of bounds: {}", idx);
}
}
}