use super::types::{OptResult, Sr1Config};
use crate::error::OptimizeError;
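/// Dense matrix–vector product `A·x` for an `n × n` matrix `a` stored row-major.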
fn mat_vec(a: &[f64], x: &[f64], n: usize) -> Vec<f64> {
(0..n)
.map(|i| (0..n).map(|j| a[i * n + j] * x[j]).sum::<f64>())
.collect()
}
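/// Matrix–vector product for a symmetric matrix. The symmetry is not exploited;
/// this is a thin alias over `mat_vec` kept for readability at call sites.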
fn sym_mat_vec(a: &[f64], x: &[f64], n: usize) -> Vec<f64> {
mat_vec(a, x, n)
}
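/// In-place rank-one update `A += scale · u·vᵀ` of a row-major `n × n` matrix.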
fn add_outer(a: &mut Vec<f64>, u: &[f64], v: &[f64], n: usize, scale: f64) {
for i in 0..n {
for j in 0..n {
a[i * n + j] += scale * u[i] * v[j];
}
}
}
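/// Dense SR1 (symmetric rank-one) update of the quasi-Newton matrix `B`, stored row-major:
///
/// `B ← B + r·rᵀ / (rᵀs)` with residual `r = y − B·s`.
///
/// Returns `true` if the update was applied. The update is skipped (returning `false`)
/// when `‖r‖` is negligible or when `|rᵀs| < skip_tol · ‖s‖ · ‖r‖`, the standard
/// safeguard against a near-zero denominator.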
pub fn sr1_update_dense(b: &mut Vec<f64>, s: &[f64], y: &[f64], n: usize, skip_tol: f64) -> bool {
let bs = sym_mat_vec(b, s, n);
let r: Vec<f64> = (0..n).map(|i| y[i] - bs[i]).collect();
    let rs: f64 = r.iter().zip(s.iter()).map(|(ri, si)| ri * si).sum();
    let r_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
    let s_norm = s.iter().map(|si| si * si).sum::<f64>().sqrt();
    if r_norm < 1e-14 || rs.abs() < skip_tol * s_norm * r_norm {
        return false;
    }
let inv_rs = 1.0 / rs;
add_outer(b, &r, &r, n, inv_rs);
true
}
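/// Limited-memory SR1 product `H·g` with initial matrix `H0 = gamma·I`, built from the
/// stored curvature pairs `(s_i, y_i)`:
///
/// `H·g ≈ gamma·g + Ψ (Ψᵀ Y)⁻¹ Ψᵀ g`, with `Ψ = S − gamma·Y`.
///
/// The small `m × m` system is solved by Gaussian elimination with partial pivoting;
/// if it is singular, the function falls back to the scaled-identity product `gamma·g`.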
pub fn lsr1_hv_product(
g: &[f64],
s_hist: &[Vec<f64>],
y_hist: &[Vec<f64>],
gamma: f64,
) -> Vec<f64> {
let n = g.len();
let m = s_hist.len();
let mut r: Vec<f64> = g.iter().map(|gi| gamma * gi).collect();
if m == 0 {
return r;
}
let psi: Vec<Vec<f64>> = (0..m)
.map(|i| {
(0..n)
.map(|j| s_hist[i][j] - gamma * y_hist[i][j])
.collect::<Vec<f64>>()
})
.collect();
let mut big_m = vec![0.0_f64; m * m];
for i in 0..m {
for j in 0..m {
big_m[i * m + j] = psi[i]
.iter()
.zip(y_hist[j].iter())
.map(|(pi, yj)| pi * yj)
.sum();
}
}
let psi_g: Vec<f64> = (0..m)
.map(|i| psi[i].iter().zip(g.iter()).map(|(pi, gi)| pi * gi).sum())
.collect();
let v = match gaussian_solve(&big_m, &psi_g, m) {
Some(x) => x,
        None => return r,
    };
for i in 0..m {
for j in 0..n {
r[j] += psi[i][j] * v[i];
}
}
r
}
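/// Solves the dense `m × m` linear system `A·x = b` by Gaussian elimination with
/// partial pivoting on an augmented matrix. Returns `None` if a pivot is numerically
/// zero (singular or near-singular system).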
fn gaussian_solve(a: &[f64], b: &[f64], m: usize) -> Option<Vec<f64>> {
let mut aug = vec![0.0_f64; m * (m + 1)];
for i in 0..m {
for j in 0..m {
aug[i * (m + 1) + j] = a[i * m + j];
}
aug[i * (m + 1) + m] = b[i];
}
for col in 0..m {
let mut max_row = col;
let mut max_val = aug[col * (m + 1) + col].abs();
for row in (col + 1)..m {
let v = aug[row * (m + 1) + col].abs();
if v > max_val {
max_val = v;
max_row = row;
}
}
        if max_val < 1e-14 {
            return None;
        }
if max_row != col {
for j in 0..=(m) {
aug.swap(col * (m + 1) + j, max_row * (m + 1) + j);
}
}
let pivot = aug[col * (m + 1) + col];
for row in (col + 1)..m {
let factor = aug[row * (m + 1) + col] / pivot;
for j in col..=(m) {
let val = aug[col * (m + 1) + j] * factor;
aug[row * (m + 1) + j] -= val;
}
}
}
let mut x = vec![0.0_f64; m];
for i in (0..m).rev() {
let rhs = aug[i * (m + 1) + m];
let diag = aug[i * (m + 1) + i];
if diag.abs() < 1e-14 {
return None;
}
let sum: f64 = ((i + 1)..m).map(|j| aug[i * (m + 1) + j] * x[j]).sum();
x[i] = (rhs - sum) / diag;
}
Some(x)
}
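/// Approximately solves the trust-region subproblem `min_d gᵀd + ½ dᵀB·d` subject to
/// `‖d‖ ≤ delta`. The full Newton step `−B⁻¹g` is returned when it fits inside the
/// radius; otherwise the Levenberg shift `λ` in `−(B + λI)⁻¹g` is found by bisection
/// so that the step length matches the radius. If the linear solves fail, a
/// steepest-descent step scaled to the radius is returned instead.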
pub fn trust_region_step(b: &[f64], g: &[f64], delta: f64, n: usize) -> Vec<f64> {
let b_vec = b.to_vec();
if let Some(d) = solve_linear_system(&b_vec, g, n) {
let d_neg: Vec<f64> = d.iter().map(|di| -di).collect();
let dnorm = d_neg.iter().map(|di| di * di).sum::<f64>().sqrt();
        if dnorm <= delta {
            return d_neg;
        }
}
let mut lam_lo = 0.0_f64;
    // Bisection on the Levenberg shift lambda in -(B + lambda*I)^{-1} g. For
    // lambda >= ‖g‖/delta + ‖B‖ the shifted step is guaranteed to fit inside the
    // trust region; for symmetric B the infinity norm (max absolute row sum)
    // bounds ‖B‖₂, so this choice of lam_hi brackets the boundary solution from above.
    let mut lam_hi = {
        let g_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
        let b_inf_norm = (0..n)
            .map(|i| (0..n).map(|j| b[i * n + j].abs()).sum::<f64>())
            .fold(0.0_f64, f64::max);
        g_norm / delta + b_inf_norm
    };
    lam_hi = lam_hi.max(1.0);
for _ in 0..50 {
let lam_mid = 0.5 * (lam_lo + lam_hi);
let mut b_reg = b_vec.clone();
for i in 0..n {
b_reg[i * n + i] += lam_mid;
}
if let Some(d) = solve_linear_system(&b_reg, g, n) {
let d_neg: Vec<f64> = d.iter().map(|di| -di).collect();
let dnorm = d_neg.iter().map(|di| di * di).sum::<f64>().sqrt();
if dnorm <= delta {
lam_hi = lam_mid;
} else {
lam_lo = lam_mid;
}
            if (lam_hi - lam_lo).abs() < 1e-12 * (1.0 + lam_mid) {
                if dnorm <= delta {
                    return d_neg;
                }
                // Converged from the infeasible side: clip the step to the boundary.
                return d_neg.iter().map(|di| di * delta / dnorm).collect();
            }
} else {
lam_lo = lam_mid;
}
}
let g_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
if g_norm < 1e-14 {
return vec![0.0; n];
}
g.iter().map(|gi| -gi * delta / g_norm).collect()
}
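/// Thin wrapper around `gaussian_solve` for the `n × n` systems arising in the
/// trust-region step.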
fn solve_linear_system(a: &[f64], b: &[f64], n: usize) -> Option<Vec<f64>> {
gaussian_solve(a, b, n)
}
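/// Trust-region optimizer driven by a limited-memory SR1 quasi-Newton model.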
pub struct Sr1Optimizer {
pub config: Sr1Config,
}
impl Sr1Optimizer {
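    /// Creates an optimizer with the given configuration.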
pub fn new(config: Sr1Config) -> Self {
Self { config }
}
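    /// Creates an optimizer with `Sr1Config::default()`.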
pub fn default_config() -> Self {
Self {
config: Sr1Config::default(),
}
}
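    /// Minimizes the objective starting from `x0`, where `f_and_g` returns the function
    /// value and its gradient at a point. Returns the final iterate, objective value,
    /// gradient norm, iteration count, and a convergence flag.
    ///
    /// A minimal usage sketch (the quadratic closure below is only an illustrative
    /// objective; any `Fn(&[f64]) -> (f64, Vec<f64>)` works):
    ///
    /// ```ignore
    /// let opt = Sr1Optimizer::default_config();
    /// // f(x) = x0² + x1², with gradient (2·x0, 2·x1).
    /// let f = |x: &[f64]| (x[0] * x[0] + x[1] * x[1], vec![2.0 * x[0], 2.0 * x[1]]);
    /// let result = opt.minimize(&f, &[1.0, -1.0]).expect("minimize failed");
    /// assert!(result.converged);
    /// ```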
pub fn minimize<F>(&self, f_and_g: &F, x0: &[f64]) -> Result<OptResult, OptimizeError>
where
F: Fn(&[f64]) -> (f64, Vec<f64>),
{
let n = x0.len();
let cfg = &self.config;
let m = cfg.m;
let mut x = x0.to_vec();
let (mut f_val, mut g) = f_and_g(&x);
let mut s_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
let mut y_hist: Vec<Vec<f64>> = Vec::with_capacity(m);
let mut gamma = 1.0_f64;
let mut delta = cfg.delta_init;
let mut n_iter = 0usize;
let mut converged = false;
for iter in 0..cfg.max_iter {
n_iter = iter;
let g_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
if g_norm < cfg.tol {
converged = true;
break;
}
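            // Quasi-Newton step: d = -H·g from the limited-memory SR1 product,
            // truncated to the trust-region radius when it is too long.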
let hg = lsr1_hv_product(&g, &s_hist, &y_hist, gamma);
let hg_norm = hg.iter().map(|v| v * v).sum::<f64>().sqrt();
let d: Vec<f64> = if hg_norm > delta {
hg.iter().map(|v| -v * delta / hg_norm).collect()
} else {
hg.iter().map(|v| -v).collect()
};
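            // Safeguard: if d is not a descent direction, fall back to a
            // steepest-descent step scaled to the trust-region radius.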
let slope: f64 = g.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
let d = if slope >= 0.0 {
let gn = g_norm.max(1e-14);
let sc = delta / gn;
g.iter().map(|gi| -gi * sc).collect::<Vec<f64>>()
} else {
d
};
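            // Evaluate the trial point and form the reduction ratio rho: the actual
            // decrease over the decrease -gᵀd predicted by the linear model.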
let x_new: Vec<f64> = x.iter().zip(d.iter()).map(|(xi, di)| xi + di).collect();
let (f_new, g_new) = f_and_g(&x_new);
let actual_red = f_val - f_new;
let gd: f64 = g.iter().zip(d.iter()).map(|(gi, di)| gi * di).sum();
let predicted_red = -gd;
let rho = if predicted_red.abs() < 1e-14 {
0.0
} else {
actual_red / predicted_red
};
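            // Accept the step when rho exceeds eta, and store the curvature pair (s, y)
            // unless the SR1 skip test rejects it; the scaling gamma is refreshed to
            // sᵀy / yᵀy when that quotient is well defined.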
if rho > cfg.eta {
let s: Vec<f64> = d.clone();
let y: Vec<f64> = (0..n).map(|i| g_new[i] - g[i]).collect();
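                // Approximate B·s via the limited-memory product with the reciprocal
                // scaling 1/gamma; used only for the SR1 skip test below.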
let bs = lsr1_hv_product(&s, &s_hist, &y_hist, 1.0 / gamma);
let r: Vec<f64> = (0..n).map(|i| y[i] - bs[i]).collect();
let rs: f64 = r.iter().zip(s.iter()).map(|(ri, si)| ri * si).sum();
let r_norm = r.iter().map(|ri| ri * ri).sum::<f64>().sqrt();
let s_norm = s.iter().map(|si| si * si).sum::<f64>().sqrt();
if rs.abs() >= cfg.skip_tol * s_norm * r_norm {
let sy: f64 = s.iter().zip(y.iter()).map(|(si, yi)| si * yi).sum();
let yy: f64 = y.iter().map(|yi| yi * yi).sum::<f64>();
if yy > 1e-14 {
                        gamma = sy / yy;
                    }
if s_hist.len() == m {
s_hist.remove(0);
y_hist.remove(0);
}
s_hist.push(s);
y_hist.push(y);
}
x = x_new;
f_val = f_new;
g = g_new;
}
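            // Standard trust-region radius update: shrink on poor agreement, expand when
            // the model fits well and the step reached the boundary; stop once the
            // radius collapses.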
if rho < 0.25 {
delta *= 0.25;
} else if rho > 0.75
&& (d.iter().map(|di| di * di).sum::<f64>().sqrt() - delta).abs() < 1e-10
{
delta = (2.0 * delta).min(cfg.delta_max);
}
if delta < 1e-12 {
                break;
            }
}
let grad_norm = g.iter().map(|gi| gi * gi).sum::<f64>().sqrt();
Ok(OptResult {
x,
f_val,
grad_norm,
n_iter,
converged,
})
}
}
#[cfg(test)]
mod tests {
use super::*;
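    /// Separable quadratic test objective `f(x) = Σ 0.5·(i+1)·xᵢ²` with gradient
    /// `gᵢ = (i+1)·xᵢ`; its unique minimizer is the origin.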
fn quadratic(x: &[f64]) -> (f64, Vec<f64>) {
let f: f64 = x
.iter()
.enumerate()
.map(|(i, xi)| 0.5 * (i as f64 + 1.0) * xi * xi)
.sum();
let g: Vec<f64> = x
.iter()
.enumerate()
.map(|(i, xi)| (i as f64 + 1.0) * xi)
.collect();
(f, g)
}
#[test]
fn test_sr1_update_formula() {
let n = 3;
let mut b = vec![0.0_f64; n * n];
for i in 0..n {
b[i * n + i] = 1.0;
}
let s = vec![1.0, 0.0, 0.0];
        let y = vec![2.0, 0.0, 0.0];
        let updated = sr1_update_dense(&mut b, &s, &y, n, 1e-8);
assert!(updated, "SR1 update should proceed");
assert!(
(b[0] - 2.0).abs() < 1e-10,
"B[0,0] should be 2, got {}",
b[0]
);
}
#[test]
fn test_sr1_skip_bad_curvature() {
let n = 2;
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let s = vec![1.0, 0.0];
        let y = vec![1.0, 0.0];
        let updated = sr1_update_dense(&mut b, &s, &y, n, 1e-8);
assert!(!updated, "SR1 update should be skipped (zero denominator)");
}
#[test]
fn test_sr1_trust_region() {
let n = 2;
        let b = vec![2.0, 0.0, 0.0, 3.0];
        let g = vec![1.0, 1.0];
let delta = 0.5_f64;
let d = trust_region_step(&b, &g, delta, n);
let d_norm = d.iter().map(|di| di * di).sum::<f64>().sqrt();
assert!(
d_norm <= delta + 1e-9,
"Trust region violated: ‖d‖={} > δ={}",
d_norm,
delta
);
}
#[test]
fn test_sr1_quadratic() {
let opt = Sr1Optimizer::default_config();
let x0 = vec![3.0, -2.0, 1.0];
let result = opt.minimize(&quadratic, &x0).expect("SR1 minimize failed");
for xi in &result.x {
assert!(xi.abs() < 0.01, "Expected x≈0, got {}", xi);
}
}
#[test]
fn test_sr1_positive_definite_approx() {
let n = 2;
let mut b = vec![1.0, 0.0, 0.0, 1.0];
let pairs = vec![
(vec![1.0, 0.0], vec![2.0, 0.0]),
(vec![0.0, 1.0], vec![0.0, 3.0]),
];
for (s, y) in &pairs {
sr1_update_dense(&mut b, s, y, n, 1e-8);
}
for (s, _y) in &pairs {
let bs = mat_vec(&b, s, n);
let sts: f64 = s.iter().zip(bs.iter()).map(|(si, bsi)| si * bsi).sum();
assert!(sts > 0.0, "B should be positive in s direction");
}
}
#[test]
fn test_sr1_symmetric_update() {
let n = 3;
let mut b = vec![2.0, 1.0, 0.0, 1.0, 3.0, 0.5, 0.0, 0.5, 4.0];
let s = vec![0.5, -0.3, 0.1];
let y = vec![1.5, 0.2, 0.4];
sr1_update_dense(&mut b, &s, &y, n, 1e-8);
for i in 0..n {
for j in 0..n {
assert!(
(b[i * n + j] - b[j * n + i]).abs() < 1e-10,
"B not symmetric at ({},{}) vs ({},{}) : {} vs {}",
i,
j,
j,
i,
b[i * n + j],
b[j * n + i]
);
}
}
}
}