use crate::hess::r#trait::HessianUpdater;
use crate::ipopt_cq::IpoptCqHandle;
use crate::ipopt_data::IpoptDataHandle;
use pounce_common::types::{Index, Number};
use pounce_linalg::compound_vector::CompoundVector;
use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
use pounce_linalg::Vector;
use std::rc::Rc;
/// One stored `(s, y)` curvature pair, plus scalars cached when the pair is
/// ingested so they need not be recomputed on every rebuild.
///
/// When produced by `update_hessian`, `s` is the primal step `x_k − x_{k−1}`
/// and `y` is the matching change in the gradient of the Lagrangian
/// (objective gradient plus Jacobian-transpose terms, all contracted with the
/// current multipliers).
#[derive(Debug, Clone)]
pub struct CurvaturePair {
    /// Primal step `s`.
    pub s: Rc<dyn Vector>,
    /// Gradient change `y` paired with `s`.
    pub y: Rc<dyn Vector>,
    /// Cached inner product `sᵀy` (`s.dot(y)`).
    pub s_dot_y: Number,
    /// Cached `s.nrm2()`; squared elsewhere to obtain `sᵀs`.
    pub s_norm: Number,
    /// Cached `y.nrm2()`; squared elsewhere to obtain `yᵀy`.
    pub y_norm: Number,
}
/// Which secant update formula the limited-memory approximation uses.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateType {
    /// Powell-damped BFGS: every accepted pair adds one column to `V` and one
    /// column to `U` of the low-rank factors.
    Bfgs,
    /// Symmetric rank-one: every accepted pair adds one column, to `V` when
    /// the SR1 denominator is positive and to `U` when it is negative.
    Sr1,
}
/// Rule for the scalar `σ` of the initial approximation `σ·I`, evaluated on the
/// most recent stored pair (see `initial_hessian_scalar`). The raw value is
/// clamped to `[init_val_min, init_val_max]` afterwards.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialApprox {
    /// `σ = 1` before clamping.
    Identity,
    /// `σ = sᵀy / sᵀs` (or `1` when `sᵀs ≤ 0`) before clamping.
    Scalar1,
    /// `σ = yᵀy / sᵀy` (or `1` when `sᵀy ≤ 0`) before clamping.
    Scalar2,
}
/// Limited-memory quasi-Newton (BFGS or SR1) approximation of the Lagrangian
/// Hessian.
///
/// Keeps a bounded history of curvature pairs (oldest first). On every
/// `update_hessian` call it rebuilds `W = σ·I + V·Vᵀ − U·Uᵀ` from the whole
/// history and stores it in the IPOPT data object.
pub struct LimMemQuasiNewtonUpdater {
    /// Secant formula (default: `Bfgs`).
    pub update_type: UpdateType,
    /// Rule for the initial scalar `σ` (default: `Scalar2`).
    pub initial_approx: InitialApprox,
    /// Maximum number of stored pairs (default: 6). Values `≤ 0` leave the
    /// history empty after every ingestion.
    pub max_history: i32,
    /// Upper clamp for `σ` (default: `1e8`).
    pub init_val_max: Number,
    /// Lower clamp for `σ` (default: `1e-8`).
    pub init_val_min: Number,
    /// Accepted pairs, oldest first; `ingest_pair` evicts from the front.
    pub history: Vec<CurvaturePair>,
    /// Primal iterate seen by the previous `update_hessian` call.
    pub last_x: Option<Rc<dyn Vector>>,
    /// Objective gradient (`curr_grad_f`) at the previous iterate.
    pub last_grad_f: Option<Rc<dyn Vector>>,
    /// Jacobian returned by `curr_jac_c` at the previous iterate.
    pub last_jac_c: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
    /// Jacobian returned by `curr_jac_d` at the previous iterate.
    pub last_jac_d: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
}
impl Default for LimMemQuasiNewtonUpdater {
fn default() -> Self {
Self {
update_type: UpdateType::Bfgs,
initial_approx: InitialApprox::Scalar2,
max_history: 6,
init_val_max: 1e8,
init_val_min: 1e-8,
history: Vec::new(),
last_x: None,
last_grad_f: None,
last_jac_c: None,
last_jac_d: None,
}
}
}
impl LimMemQuasiNewtonUpdater {
    /// Creates an updater with the `Default` settings.
    pub fn new() -> Self {
        Self::default()
    }

    /// Offers a curvature pair `(s, y)` to the history.
    ///
    /// The pair is screened on `sᵀy` with the safeguard of the active update
    /// type (`bfgs_curvature_pair_ok` or `sr1_denominator_ok`).
    ///
    /// * A rejected pair leaves the history untouched, and `false` is returned.
    /// * An accepted pair is appended with `sᵀy`, `‖s‖` and `‖y‖` cached. The
    ///   oldest pairs are then evicted so at most `max(max_history, 0)` remain,
    ///   and `true` is returned. This holds even when a non-positive
    ///   `max_history` evicts the new pair immediately.
    pub fn ingest_pair(&mut self, s: Rc<dyn Vector>, y: Rc<dyn Vector>) -> bool {
        let s_dot_y = s.dot(&*y);
        let s_norm = s.nrm2();
        let y_norm = y.nrm2();
        let accept = match self.update_type {
            UpdateType::Bfgs => bfgs_curvature_pair_ok(s_dot_y, s_norm, y_norm),
            // Only `sᵀy` is screened here. The actual SR1 denominator
            // `(y − B s)ᵀs` depends on the current approximation and is
            // re-checked when the factors are rebuilt.
            UpdateType::Sr1 => sr1_denominator_ok(s_dot_y, s_norm, y_norm),
        };
        if !accept {
            return false;
        }
        self.history.push(CurvaturePair {
            s,
            y,
            s_dot_y,
            s_norm,
            y_norm,
        });
        // Trim from the front in one pass. Repeated `remove(0)` would shift the
        // whole vector once per evicted pair.
        let cap = self.max_history.max(0) as usize;
        if self.history.len() > cap {
            let excess = self.history.len() - cap;
            self.history.drain(..excess);
        }
        true
    }
}
impl HessianUpdater for LimMemQuasiNewtonUpdater {
    /// Rebuilds the limited-memory Hessian approximation and stores it in
    /// `data.w`.
    ///
    /// 1. Reads the current iterate `(x, y_c, y_d)`. If there is none, it
    ///    returns `true` without touching `data.w`.
    /// 2. If a previous iterate was recorded, it forms the step `s` and the
    ///    gradient change `y` and offers the pair to `ingest_pair`. A rejected
    ///    pair is simply not stored.
    /// 3. It records the current `x`, `∇f`, `J_c`, `J_d` for the next call,
    ///    whether or not the pair was accepted.
    /// 4. It computes the initial scalar `σ`, rebuilds the dense `V`/`U`
    ///    factors from the full history, and stores a `LowRankUpdateSymMatrix`
    ///    with diagonal `σ` and those factors in `data.w`.
    ///
    /// Always returns `true`.
    fn update_hessian(&mut self, data: &IpoptDataHandle, cq: &IpoptCqHandle) -> bool {
        let (curr_x, curr_y_c, curr_y_d) = match data.borrow().curr.as_ref() {
            Some(c) => (c.x.clone(), c.y_c.clone(), c.y_d.clone()),
            None => return true,
        };
        let curr_grad_f = cq.borrow().curr_grad_f();
        let curr_jac_c = cq.borrow().curr_jac_c();
        let curr_jac_d = cq.borrow().curr_jac_d();
        if let (Some(prev_x), Some(prev_grad_f), Some(prev_jac_c), Some(prev_jac_d)) = (
            self.last_x.clone(),
            self.last_grad_f.clone(),
            self.last_jac_c.clone(),
            self.last_jac_d.clone(),
        ) {
            // s = x_k - x_{k-1}
            let mut s = curr_x.make_new();
            s.add_two_vectors(1.0, &*curr_x, -1.0, &*prev_x, 0.0);
            // y = ∇f(x_k) - ∇f(x_{k-1})
            //   + J_c(x_k)ᵀ y_c - J_c(x_{k-1})ᵀ y_c
            //   + J_d(x_k)ᵀ y_d - J_d(x_{k-1})ᵀ y_d
            // Both Jacobians are contracted with the *current* multipliers.
            // Assumes trans_mult_vector(a, v, b, out) computes
            // out ← a·Jᵀv + b·out — confirm against pounce_linalg.
            let mut y = curr_x.make_new();
            y.add_two_vectors(1.0, &*curr_grad_f, -1.0, &*prev_grad_f, 0.0);
            curr_jac_c.trans_mult_vector(1.0, &*curr_y_c, 1.0, &mut *y);
            prev_jac_c.trans_mult_vector(-1.0, &*curr_y_c, 1.0, &mut *y);
            curr_jac_d.trans_mult_vector(1.0, &*curr_y_d, 1.0, &mut *y);
            prev_jac_d.trans_mult_vector(-1.0, &*curr_y_d, 1.0, &mut *y);
            // The acceptance flag returned by ingest_pair is ignored here.
            self.ingest_pair(Rc::from(s), Rc::from(y));
        }
        self.last_x = Some(Rc::clone(&curr_x));
        self.last_grad_f = Some(Rc::clone(&curr_grad_f));
        self.last_jac_c = Some(Rc::clone(&curr_jac_c));
        self.last_jac_d = Some(Rc::clone(&curr_jac_d));
        let n_idx = curr_x.dim();
        let nu = n_idx as usize;
        // NOTE(review): both arms call compute_sigma_bfgs, so SR1 uses the
        // same initial scaling as BFGS and this match is redundant.
        let sigma = match self.update_type {
            UpdateType::Bfgs => self.compute_sigma_bfgs(),
            UpdateType::Sr1 => self.compute_sigma_bfgs(),
        };
        let (v_cols, u_cols) = self.build_low_rank(sigma, nu);
        // NOTE(review): the multi-vector column space is dense, but
        // build_multi_vector creates each column with curr_x.make_new(). For
        // a compound primal vector the columns are therefore compound —
        // confirm MultiVectorMatrix accepts that.
        let col_space = DenseVectorSpace::new(n_idx);
        // Diagonal part of W: every entry set to sigma.
        let mut diag = curr_x.make_new();
        diag.set(sigma);
        let lr_space = LowRankUpdateSymMatrixSpace::new(n_idx, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(Rc::from(diag));
        // Empty factor lists yield None, leaving V / U unset.
        if let Some(mvm) = build_multi_vector(&col_space, curr_x.as_ref(), &v_cols) {
            lr.set_v(Rc::new(mvm));
        }
        if let Some(mvm) = build_multi_vector(&col_space, curr_x.as_ref(), &u_cols) {
            lr.set_u(Rc::new(mvm));
        }
        data.borrow_mut().w = Some(Rc::new(lr));
        true
    }
}
impl LimMemQuasiNewtonUpdater {
    /// Initial scalar `σ` for the approximation `σ·I`.
    ///
    /// Applies `self.initial_approx` to the most recent stored pair, clamped to
    /// `[init_val_min, init_val_max]` by `initial_hessian_scalar`. With an
    /// empty history it returns `1.0` without clamping.
    fn compute_sigma_bfgs(&self) -> Number {
        let last = match self.history.last() {
            Some(pair) => pair,
            None => return 1.0,
        };
        let s_dot_s = last.s_norm * last.s_norm;
        let y_dot_y = last.y_norm * last.y_norm;
        initial_hessian_scalar(
            self.initial_approx,
            s_dot_s,
            last.s_dot_y,
            y_dot_y,
            self.init_val_min,
            self.init_val_max,
        )
    }

    /// Rebuilds the dense low-rank factors of the approximation from the whole
    /// history, starting from `B₀ = σ·I`.
    ///
    /// Returns `(V, U)` as lists of length-`n` columns such that
    /// `B = σ·I + Σ v·vᵀ − Σ u·uᵀ`. Pairs are replayed oldest first. Each step
    /// computes `B s` with the factors built so far, then applies:
    ///
    /// * **BFGS** (Powell-damped): `B⁺ = B − (Bs)(Bs)ᵀ/(sᵀBs) + r·rᵀ/(sᵀr)`,
    ///   with `r = θy + (1 − θ)Bs` and `θ = powell_damping_theta(sᵀy, sᵀBs)`.
    ///   It pushes `r/√(sᵀr)` to `V` and `Bs/√(sᵀBs)` to `U`.
    /// * **SR1**: `B⁺ = B + w·wᵀ/(wᵀs)` with `w = y − Bs`. It pushes `w/√|wᵀs|`
    ///   to `V` when `wᵀs > 0`, otherwise to `U`.
    ///
    /// A pair is skipped when its BFGS curvature `sᵀBs` or `sᵀr` is not
    /// strictly positive and finite, or when it fails the SR1 denominator
    /// safeguard. This keeps one bad pair from injecting NaN/∞ columns into
    /// `W`. Returns empty factors when `n == 0`.
    fn build_low_rank(&self, sigma: Number, n: usize) -> (Vec<Vec<Number>>, Vec<Vec<Number>>) {
        let mut v_cols: Vec<Vec<Number>> = Vec::new();
        let mut u_cols: Vec<Vec<Number>> = Vec::new();
        if n == 0 {
            return (v_cols, u_cols);
        }
        for pair in &self.history {
            let s = dense_from_vec(pair.s.as_ref(), n);
            let y = dense_from_vec(pair.y.as_ref(), n);
            let bs = low_rank_apply(sigma, &v_cols, &u_cols, &s);
            match self.update_type {
                UpdateType::Bfgs => {
                    let s_bs = dense_dot(&s, &bs);
                    // `<= 0.0` alone would let NaN through, and ∞ makes θ NaN.
                    if !(s_bs > 0.0 && s_bs.is_finite()) {
                        continue;
                    }
                    let sy = pair.s_dot_y;
                    let theta = powell_damping_theta(sy, s_bs);
                    let sr = theta * sy + (1.0 - theta) * s_bs;
                    if !(sr > 0.0 && sr.is_finite()) {
                        continue;
                    }
                    let r_scale = 1.0 / sr.sqrt();
                    let bs_scale = 1.0 / s_bs.sqrt();
                    v_cols.push(
                        y.iter()
                            .zip(&bs)
                            .map(|(&yi, &bi)| (theta * yi + (1.0 - theta) * bi) * r_scale)
                            .collect(),
                    );
                    u_cols.push(bs.iter().map(|&bi| bi * bs_scale).collect());
                }
                UpdateType::Sr1 => {
                    let yms: Vec<Number> = y.iter().zip(&bs).map(|(&yi, &bi)| yi - bi).collect();
                    let denom = dense_dot(&yms, &s);
                    let yms_norm = dense_dot(&yms, &yms).sqrt();
                    // Also rejects NaN/∞ denominators: the comparison is false.
                    if !sr1_denominator_ok(denom, pair.s_norm, yms_norm) {
                        continue;
                    }
                    let scale = 1.0 / denom.abs().sqrt();
                    let col: Vec<Number> = yms.iter().map(|&w| w * scale).collect();
                    // The sign of the denominator decides whether w·wᵀ/denom
                    // is added (V) or subtracted (U).
                    if denom > 0.0 {
                        v_cols.push(col);
                    } else {
                        u_cols.push(col);
                    }
                }
            }
        }
        (v_cols, u_cols)
    }
}

/// Inner product of two equal-length slices, summed left to right.
fn dense_dot(a: &[Number], b: &[Number]) -> Number {
    a.iter().zip(b).map(|(&x, &y)| x * y).sum()
}

/// Computes `B·s` for `B = σ·I + Σ v·vᵀ − Σ u·uᵀ` without forming `B`.
/// Applies all `V` columns, then all `U` columns.
fn low_rank_apply(
    sigma: Number,
    v_cols: &[Vec<Number>],
    u_cols: &[Vec<Number>],
    s: &[Number],
) -> Vec<Number> {
    let mut bs: Vec<Number> = s.iter().map(|&si| sigma * si).collect();
    for v in v_cols {
        let c = dense_dot(v, s);
        for (b, &vi) in bs.iter_mut().zip(v) {
            *b += c * vi;
        }
    }
    for u in u_cols {
        let c = dense_dot(u, s);
        for (b, &ui) in bs.iter_mut().zip(u) {
            *b -= c * ui;
        }
    }
    bs
}
/// Packs dense columns into a multi-vector matrix over `col_space`, or returns
/// `None` when there are no columns.
///
/// Each column is created with `template.make_new()` and filled through
/// `set_expanded`, so the values must follow `template`'s flattened layout.
fn build_multi_vector(
    col_space: &Rc<DenseVectorSpace>,
    template: &dyn Vector,
    cols: &[Vec<Number>],
) -> Option<MultiVectorMatrix> {
    let n_cols = match cols.len() {
        0 => return None,
        len => len as Index,
    };
    let space = MultiVectorMatrixSpace::new(n_cols, Rc::clone(col_space));
    let mut packed = space.make_new_multi_vector();
    for (col_idx, values) in cols.iter().enumerate() {
        let mut column = template.make_new();
        set_expanded(column.as_mut(), values);
        packed.set_vector(col_idx as Index, Rc::from(column));
    }
    Some(packed)
}
/// Flattens a primal vector into one contiguous `Vec` of values.
///
/// A `DenseVector` contributes `expanded_values()`. A `CompoundVector` is
/// flattened recursively, component by component in index order.
///
/// # Panics
/// If `v`, or any nested component, is neither a `DenseVector` nor a
/// `CompoundVector`.
fn expanded_of(v: &dyn Vector) -> Vec<Number> {
    let erased = v.as_any();
    if let Some(dense) = erased.downcast_ref::<DenseVector>() {
        dense.expanded_values()
    } else if let Some(compound) = erased.downcast_ref::<CompoundVector>() {
        (0..compound.n_comps())
            .flat_map(|comp_idx| expanded_of(compound.comp(comp_idx)))
            .collect()
    } else {
        panic!("LimMemQuasiNewtonUpdater: unsupported primal vector type for expansion");
    }
}
/// Scatters a flat slice of values into `dst`; the inverse of `expanded_of`.
///
/// A `DenseVector` receives the whole slice via `set_values`. A
/// `CompoundVector` hands consecutive sub-slices, sized by each component's
/// `dim()`, to its components recursively. Values beyond the compound's total
/// dimension are ignored.
///
/// # Panics
/// If `flat` is shorter than a compound vector's total dimension, or if a
/// (nested) vector is neither dense nor compound.
fn set_expanded(dst: &mut dyn Vector, flat: &[Number]) {
    if let Some(dense) = dst.as_any_mut().downcast_mut::<DenseVector>() {
        dense.set_values(flat);
        return;
    }
    let compound = match dst.as_any_mut().downcast_mut::<CompoundVector>() {
        Some(c) => c,
        None => {
            panic!("LimMemQuasiNewtonUpdater: unsupported primal vector type for set_expanded")
        }
    };
    let widths: Vec<usize> = (0..compound.n_comps())
        .map(|comp_idx| compound.comp(comp_idx).dim() as usize)
        .collect();
    let mut remaining = flat;
    for (comp_idx, width) in widths.into_iter().enumerate() {
        let (head, tail) = remaining.split_at(width);
        set_expanded(compound.comp_mut(comp_idx as Index), head);
        remaining = tail;
    }
}
/// Flattens `v` with `expanded_of`. Debug builds also check that the result
/// has exactly `n` entries.
fn dense_from_vec(v: &dyn Vector, n: usize) -> Vec<Number> {
    let values: Vec<Number> = expanded_of(v);
    debug_assert_eq!(values.len(), n);
    values
}
/// Scalar `σ` for the initial Hessian approximation `σ·I`.
///
/// * `Identity` → `1`
/// * `Scalar1` → `sᵀy / sᵀs`, or `1` when `sᵀs ≤ 0`
/// * `Scalar2` → `yᵀy / sᵀy`, or `1` when `sᵀy ≤ 0`
///
/// The raw value is clamped to `[min_val, max_val]`. A NaN ratio (possible
/// only with non-finite inputs, e.g. `∞/∞`) is treated like the other
/// undefined cases and replaced by `1` before clamping, so NaN never reaches
/// the Hessian diagonal. Infinite ratios clamp to the nearest bound.
///
/// # Panics
/// Like `f64::clamp`, panics if `min_val > max_val` or either bound is NaN.
pub fn initial_hessian_scalar(
    init: InitialApprox,
    s_dot_s: Number,
    s_dot_y: Number,
    y_dot_y: Number,
    min_val: Number,
    max_val: Number,
) -> Number {
    let raw = match init {
        InitialApprox::Identity => 1.0,
        InitialApprox::Scalar1 => {
            if s_dot_s > 0.0 {
                s_dot_y / s_dot_s
            } else {
                1.0
            }
        }
        InitialApprox::Scalar2 => {
            if s_dot_y > 0.0 {
                y_dot_y / s_dot_y
            } else {
                1.0
            }
        }
    };
    // `clamp` returns NaN unchanged, so map it to the "undefined ratio" fallback.
    let raw = if raw.is_nan() { 1.0 } else { raw };
    raw.clamp(min_val, max_val)
}
/// Powell's damping factor `θ` for a BFGS update.
///
/// Returns `1` (no damping) when `sᵀy ≥ 0.2·sᵀBs`. Otherwise it returns
/// `θ = 0.8·sᵀBs / (sᵀBs − sᵀy)`, which makes the damped vector
/// `r = θy + (1 − θ)Bs` satisfy `sᵀr = 0.2·sᵀBs`. For `sᵀBs > 0` this gives
/// `θ ∈ (0, 1]`. Falls back to `1` when the denominator is not positive,
/// including NaN inputs.
pub fn powell_damping_theta(s_dot_y: Number, s_dot_b_s: Number) -> Number {
    let curvature_ok = s_dot_y >= 0.2 * s_dot_b_s;
    let gap = s_dot_b_s - s_dot_y;
    match (curvature_ok, gap > 0.0) {
        (false, true) => 0.8 * s_dot_b_s / gap,
        _ => 1.0,
    }
}
/// BFGS acceptance test: the pair is kept only if `sᵀy > 1e-8·‖s‖·‖y‖`, i.e.
/// the curvature is positive relative to the size of the vectors. Any NaN
/// input makes the comparison false, so the pair is rejected.
pub fn bfgs_curvature_pair_ok(s_dot_y: Number, s_norm: Number, y_norm: Number) -> bool {
    const REL_TOL: f64 = 1e-8;
    let floor = REL_TOL * s_norm * y_norm;
    floor < s_dot_y
}
/// SR1 safeguard: the denominator `(y − Bs)ᵀs` must satisfy
/// `|(y − Bs)ᵀs| > 1e-8·‖s‖·‖y − Bs‖`. Either sign is accepted; any NaN input
/// makes the comparison false.
pub fn sr1_denominator_ok(yms_dot_s: Number, s_norm: Number, yms_norm: Number) -> bool {
    const REL_TOL: f64 = 1e-8;
    let floor = REL_TOL * s_norm * yms_norm;
    floor < yms_dot_s.abs()
}
#[cfg(test)]
mod tests {
    use super::*;
    #[test]
    fn identity_init_returns_one() {
        assert_eq!(
            initial_hessian_scalar(InitialApprox::Identity, 1.0, 1.0, 1.0, 1e-8, 1e8),
            1.0
        );
    }
    #[test]
    fn scalar1_init_is_sy_over_ss() {
        let v = initial_hessian_scalar(InitialApprox::Scalar1, 4.0, 2.0, 0.0, 1e-8, 1e8);
        assert!((v - 0.5).abs() < 1e-15);
    }
    /// With a zero step, `sᵀy / sᵀs` is undefined and the rule falls back to 1.
    #[test]
    fn scalar1_falls_back_to_one_without_step() {
        let v = initial_hessian_scalar(InitialApprox::Scalar1, 0.0, 5.0, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1.0);
    }
    #[test]
    fn scalar2_init_is_yy_over_sy() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 2.0, 8.0, 1e-8, 1e8);
        assert!((v - 4.0).abs() < 1e-15);
    }
    #[test]
    fn init_clamped_to_max() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e-20, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1e8);
    }
    #[test]
    fn init_clamped_to_min() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e20, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1e-8);
    }
    #[test]
    fn powell_no_damping_when_curvature_ok() {
        assert_eq!(powell_damping_theta(1.0, 1.0), 1.0);
    }
    #[test]
    fn powell_damps_when_curvature_violated() {
        let theta = powell_damping_theta(0.1, 1.0);
        assert!((theta - 8.0 / 9.0).abs() < 1e-15);
    }
    #[test]
    fn bfgs_skip_criterion() {
        assert!(bfgs_curvature_pair_ok(1.0, 1.0, 1.0));
        assert!(!bfgs_curvature_pair_ok(1e-10, 1.0, 1.0));
    }
    #[test]
    fn sr1_skip_criterion_uses_absolute_value() {
        assert!(sr1_denominator_ok(-1.0, 1.0, 1.0));
        assert!(!sr1_denominator_ok(1e-10, 1.0, 1.0));
    }
    /// Builds a shared dense vector holding `values`.
    fn rcv(values: &[Number]) -> Rc<dyn Vector> {
        let mut v = pounce_linalg::dense_vector::DenseVectorSpace::new(values.len() as i32)
            .make_new_dense();
        v.set(0.0);
        v.values_mut().copy_from_slice(values);
        Rc::new(v)
    }
    #[test]
    fn ingest_pair_accepts_well_curved_pair() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        let accepted = updater.ingest_pair(rcv(&[1.0, 0.0]), rcv(&[1.0, 0.0]));
        assert!(accepted);
        assert_eq!(updater.history.len(), 1);
        let pair = &updater.history[0];
        assert!((pair.s_dot_y - 1.0).abs() < 1e-15);
        assert!((pair.s_norm - 1.0).abs() < 1e-15);
        assert!((pair.y_norm - 1.0).abs() < 1e-15);
    }
    #[test]
    fn ingest_pair_skips_zero_curvature() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        let accepted = updater.ingest_pair(rcv(&[1.0]), rcv(&[0.0]));
        assert!(!accepted);
        assert!(updater.history.is_empty());
    }
    #[test]
    fn history_caps_at_max_history() {
        let mut updater = LimMemQuasiNewtonUpdater {
            max_history: 2,
            ..LimMemQuasiNewtonUpdater::default()
        };
        for _ in 0..5 {
            updater.ingest_pair(rcv(&[1.0]), rcv(&[1.0]));
        }
        assert_eq!(updater.history.len(), 2);
    }
    /// The cap evicts the oldest pairs first and keeps the newest ones in order.
    #[test]
    fn history_evicts_oldest_pairs_first() {
        let mut updater = LimMemQuasiNewtonUpdater {
            max_history: 2,
            ..LimMemQuasiNewtonUpdater::default()
        };
        for &v in &[1.0, 2.0, 3.0, 4.0] {
            assert!(updater.ingest_pair(rcv(&[v]), rcv(&[v])));
        }
        let kept: Vec<Number> = updater.history.iter().map(|p| p.s_dot_y).collect();
        assert_eq!(kept, vec![9.0, 16.0]);
    }
    /// A non-positive cap leaves no stored history.
    #[test]
    fn non_positive_max_history_keeps_nothing() {
        let mut updater = LimMemQuasiNewtonUpdater {
            max_history: 0,
            ..LimMemQuasiNewtonUpdater::default()
        };
        updater.ingest_pair(rcv(&[1.0]), rcv(&[1.0]));
        assert!(updater.history.is_empty());
    }
    #[test]
    fn sr1_path_routes_through_sr1_skip() {
        let mut updater = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        assert!(updater.ingest_pair(rcv(&[1.0]), rcv(&[-1.0])));
    }
    /// Builds a `CurvaturePair` with the same cached scalars `ingest_pair`
    /// computes, without running the acceptance test.
    fn pair(s: &[Number], y: &[Number]) -> CurvaturePair {
        let s_rc = rcv(s);
        let y_rc = rcv(y);
        let s_dot_y = s_rc.dot(&*y_rc);
        let s_norm = s_rc.nrm2();
        let y_norm = y_rc.nrm2();
        CurvaturePair {
            s: s_rc,
            y: y_rc,
            s_dot_y,
            s_norm,
            y_norm,
        }
    }
    /// Assembles the dense row-major `n×n` matrix `σ·I + Σ v·vᵀ − Σ u·uᵀ`.
    fn reconstruct_b(n: usize, sigma: Number, v: &[Vec<Number>], u: &[Vec<Number>]) -> Vec<Number> {
        let mut b = vec![0.0_f64; n * n];
        for i in 0..n {
            b[i * n + i] = sigma;
        }
        for col in v {
            for i in 0..n {
                for j in 0..n {
                    b[i * n + j] += col[i] * col[j];
                }
            }
        }
        for col in u {
            for i in 0..n {
                for j in 0..n {
                    b[i * n + j] -= col[i] * col[j];
                }
            }
        }
        b
    }
    /// Dense row-major matrix-vector product.
    fn mat_vec(b: &[Number], n: usize, x: &[Number]) -> Vec<Number> {
        (0..n)
            .map(|i| (0..n).map(|j| b[i * n + j] * x[j]).sum())
            .collect()
    }
    #[test]
    fn bfgs_low_rank_recovers_hessian_action() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 1.0], &[2.0, 5.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 1.0]);
        assert!((bs[0] - 2.0).abs() < 1e-12, "Bs[0]={}", bs[0]);
        assert!((bs[1] - 5.0).abs() < 1e-12, "Bs[1]={}", bs[1]);
    }
    /// With `sᵀy = 0.1 < 0.2·sᵀBs = 0.2`, Powell damping replaces the secant
    /// condition by `B⁺s = r`, whose curvature is `sᵀr = 0.2`.
    #[test]
    fn bfgs_low_rank_applies_powell_damping() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 0.0], &[0.1, 0.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        assert_eq!((v.len(), u.len()), (1, 1));
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 0.0]);
        assert!((bs[0] - 0.2).abs() < 1e-12, "Bs[0]={}", bs[0]);
        assert!(bs[1].abs() < 1e-12, "Bs[1]={}", bs[1]);
        assert!((b[3] - 1.0).abs() < 1e-12);
    }
    #[test]
    fn bfgs_low_rank_keeps_symmetry() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 0.5], &[2.0, 1.0]));
        up.history.push(pair(&[0.7, 1.2], &[1.0, 2.5]));
        let (v, u) = up.build_low_rank(3.0, 2);
        let b = reconstruct_b(2, 3.0, &v, &u);
        assert!((b[1] - b[2]).abs() < 1e-12);
    }
    #[test]
    fn sr1_low_rank_recovers_hessian_action() {
        let mut up = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        up.history.push(pair(&[1.0, 1.0], &[2.0, 5.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        assert_eq!(v.len(), 1, "positive denom routes to V");
        assert!(u.is_empty());
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 1.0]);
        assert!((bs[0] - 2.0).abs() < 1e-12);
        assert!((bs[1] - 5.0).abs() < 1e-12);
    }
    /// `(y − Bs)ᵀs = −0.5 < 0`, so the SR1 column must be subtracted (U) and
    /// the secant condition `B⁺s = y` must still hold.
    #[test]
    fn sr1_negative_denominator_routes_to_u() {
        let mut up = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        up.history.push(pair(&[1.0, 0.0], &[0.5, 0.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        assert!(v.is_empty());
        assert_eq!(u.len(), 1, "negative denom routes to U");
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 0.0]);
        assert!((bs[0] - 0.5).abs() < 1e-12, "Bs[0]={}", bs[0]);
        assert!(bs[1].abs() < 1e-12, "Bs[1]={}", bs[1]);
    }
    #[test]
    fn empty_history_yields_no_columns() {
        let up = LimMemQuasiNewtonUpdater::new();
        let (v, u) = up.build_low_rank(1.0, 4);
        assert!(v.is_empty() && u.is_empty());
    }
}