use crate::sqp::qp_assembly::Triplet;
use pounce_common::types::{Index, Number};
use std::collections::VecDeque;
/// Limited-memory, damped BFGS approximation built from differences of
/// iterates and Lagrangian gradients.
///
/// Feed successive points through [`LBfgs::update`]; [`LBfgs::as_triplet`] then
/// materializes the current approximation as the lower triangle of a dense
/// symmetric `n × n` matrix in 1-based triplet form.
#[derive(Debug, Clone)]
pub struct LBfgs {
    /// Problem dimension (required length of every iterate and gradient).
    n: usize,
    /// Maximum number of correction pairs retained.
    m_history: usize,
    /// Correction pairs `(s, y)`, oldest at the front, newest at the back.
    pairs: VecDeque<(Vec<Number>, Vec<Number>)>,
    /// Iterate passed to the most recent `update` call, if any.
    prev_x: Option<Vec<Number>>,
    /// Lagrangian gradient passed to the most recent `update` call, if any.
    prev_grad_lag: Option<Vec<Number>>,
    /// 1-based row indices of the lower-triangular pattern, listed row by row.
    h_irow: Vec<Index>,
    /// 1-based column indices paired element-wise with `h_irow`.
    h_jcol: Vec<Index>,
}
impl LBfgs {
    /// Powell damping is applied when `sᵀy < POWELL_THRESHOLD · sᵀBs`.
    const POWELL_THRESHOLD: Number = 0.2;
    /// Numerator factor of the damping coefficient `θ = POWELL_FACTOR · sᵀBs / (sᵀBs − sᵀy)`.
    const POWELL_FACTOR: Number = 0.8;
    /// Steps with `‖s‖²` at or below this value are treated as "no movement" and not recorded.
    const MIN_STEP_NORM2: Number = 1e-30;

    /// Creates an empty approximation for an `n`-dimensional problem that keeps
    /// at most `m_history` correction pairs.
    ///
    /// The lower-triangular sparsity pattern (1-based indices, row by row) returned
    /// by [`LBfgs::as_triplet`] is precomputed here.
    ///
    /// # Panics
    ///
    /// In debug builds, panics if `m_history == 0`. In release builds a zero history
    /// is tolerated: no pairs are ever stored and the approximation stays the identity.
    pub fn new(n: usize, m_history: usize) -> Self {
        debug_assert!(m_history >= 1);
        let nz = n * (n + 1) / 2;
        let mut h_irow = Vec::with_capacity(nz);
        let mut h_jcol = Vec::with_capacity(nz);
        for i in 0..n {
            for j in 0..=i {
                h_irow.push((i + 1) as Index);
                h_jcol.push((j + 1) as Index);
            }
        }
        Self {
            n,
            m_history,
            pairs: VecDeque::with_capacity(m_history),
            prev_x: None,
            prev_grad_lag: None,
            h_irow,
            h_jcol,
        }
    }

    /// Returns `true` once at least one point has been recorded by [`LBfgs::update`].
    pub fn has_prev(&self) -> bool {
        self.prev_x.is_some()
    }

    /// Records a new iterate and its Lagrangian gradient.
    ///
    /// When a previous point exists, the correction pair `s = x_new − x_prev`,
    /// `y = grad_lag_new − grad_lag_prev` is appended, evicting the oldest pair once
    /// `m_history` pairs are held. A pair is skipped when `‖s‖² ≤ 1e-30` (no real
    /// step) or when any entry of `s` or `y` is non-finite, so a single bad gradient
    /// cannot poison the materialized matrix. The new point always becomes the
    /// reference for the next call.
    ///
    /// # Panics
    ///
    /// Panics if `x_new` or `grad_lag_new` does not have length `n`.
    pub fn update(&mut self, x_new: &[Number], grad_lag_new: &[Number]) {
        assert_eq!(x_new.len(), self.n, "LBFGS::update: x_new.len() != n");
        assert_eq!(
            grad_lag_new.len(),
            self.n,
            "LBFGS::update: grad_lag_new.len() != n"
        );
        if let (Some(prev_x), Some(prev_grad_lag)) =
            (self.prev_x.as_ref(), self.prev_grad_lag.as_ref())
        {
            let s: Vec<Number> = x_new
                .iter()
                .zip(prev_x.iter())
                .map(|(a, b)| a - b)
                .collect();
            let y: Vec<Number> = grad_lag_new
                .iter()
                .zip(prev_grad_lag.iter())
                .map(|(a, b)| a - b)
                .collect();
            let s_norm2 = Self::dot(&s, &s);
            let finite = s.iter().chain(&y).all(|v| v.is_finite());
            // `m_history > 0` matters in release builds: with a zero cap the old
            // `len == m_history` check let the deque grow without bound.
            if s_norm2 > Self::MIN_STEP_NORM2 && finite && self.m_history > 0 {
                while self.pairs.len() >= self.m_history {
                    self.pairs.pop_front();
                }
                self.pairs.push_back((s, y));
            }
        }
        self.prev_x = Some(x_new.to_vec());
        self.prev_grad_lag = Some(grad_lag_new.to_vec());
    }

    /// Returns the current approximation as the lower triangle of a symmetric
    /// `n × n` matrix in triplet form with 1-based indices, ordered row by row
    /// (the same order as the pattern built in [`LBfgs::new`]).
    ///
    /// The matrix is materialized densely on every call (`n²` storage).
    pub fn as_triplet(&self) -> Triplet {
        let n = self.n;
        let dense = self.materialize();
        let mut vals = Vec::with_capacity(self.h_irow.len());
        for i in 0..n {
            // Row i, columns 0..=i of the row-major dense matrix.
            vals.extend_from_slice(&dense[i * n..=i * n + i]);
        }
        Triplet {
            n_rows: n,
            n_cols: n,
            irow: self.h_irow.clone(),
            jcol: self.h_jcol.clone(),
            vals,
        }
    }

    /// Euclidean inner product of two slices (extra elements of the longer one are ignored).
    fn dot(a: &[Number], b: &[Number]) -> Number {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    }

    /// Builds the dense, row-major `n × n` damped-BFGS matrix from the stored pairs.
    ///
    /// The start is `γ·I` with `γ = yᵀy / sᵀy` of the most recent pair. It falls
    /// back to `γ = 1` when there are no pairs or either quantity is `≤ 1e-30`.
    /// The pairs are then applied oldest-first with Powell damping:
    ///
    /// `B ← B − (Bs)(Bs)ᵀ / sᵀBs + ŷŷᵀ / sᵀŷ`, where `ŷ = θy + (1 − θ)Bs`.
    ///
    /// Here `θ = 0.8·sᵀBs / (sᵀBs − sᵀy)` when `sᵀy < 0.2·sᵀBs` (and
    /// `sᵀBs − sᵀy > 1e-14`), otherwise `θ = 1`. A pair is skipped when `sᵀBs`
    /// or `sᵀŷ` is not above `1e-14`.
    fn materialize(&self) -> Vec<Number> {
        let n = self.n;
        let gamma = self
            .pairs
            .back()
            .and_then(|(s, y)| {
                let sy = Self::dot(s, y);
                let yy = Self::dot(y, y);
                if yy > 1e-30 && sy > 1e-30 {
                    Some(yy / sy)
                } else {
                    None
                }
            })
            .unwrap_or(1.0);

        let mut b: Vec<Number> = vec![0.0; n * n];
        for i in 0..n {
            b[i * n + i] = gamma;
        }

        // Scratch buffers for B·s and the damped y, reused across pairs.
        let mut bs: Vec<Number> = vec![0.0; n];
        let mut y_damp: Vec<Number> = vec![0.0; n];
        for (s, y) in &self.pairs {
            for (i, bs_i) in bs.iter_mut().enumerate() {
                *bs_i = Self::dot(&b[i * n..(i + 1) * n], s);
            }
            let s_bs = Self::dot(s, &bs);
            let s_y = Self::dot(s, y);
            // Powell damping: blend y toward B·s so that sᵀŷ = 0.2·sᵀBs > 0.
            let theta = if s_y < Self::POWELL_THRESHOLD * s_bs && s_bs - s_y > 1e-14 {
                Self::POWELL_FACTOR * s_bs / (s_bs - s_y)
            } else {
                1.0
            };
            for ((yd, yi), bsi) in y_damp.iter_mut().zip(y).zip(&bs) {
                *yd = theta * yi + (1.0 - theta) * bsi;
            }
            let s_y_damp = Self::dot(s, &y_damp);
            if s_bs > 1e-14 && s_y_damp > 1e-14 {
                for i in 0..n {
                    let (bs_i, yd_i) = (bs[i], y_damp[i]);
                    let row = &mut b[i * n..(i + 1) * n];
                    for ((b_ij, &bs_j), &yd_j) in row.iter_mut().zip(&bs).zip(&y_damp) {
                        *b_ij += -(bs_i * bs_j) / s_bs + (yd_i * yd_j) / s_y_damp;
                    }
                }
            }
        }
        b
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expands the lower-triangular, 1-based triplet produced by `as_triplet`
    /// into a dense symmetric matrix.
    fn dense(t: &Triplet) -> Vec<Vec<f64>> {
        let mut b = vec![vec![0.0_f64; t.n_cols]; t.n_rows];
        for ((&r, &c), &v) in t.irow.iter().zip(&t.jcol).zip(&t.vals) {
            let (i, j) = ((r - 1) as usize, (c - 1) as usize);
            b[i][j] = v;
            b[j][i] = v;
        }
        b
    }

    /// Asserts that `actual` has the shape of `expected` and matches it entry-wise within `tol`.
    fn assert_close(actual: &[Vec<f64>], expected: &[Vec<f64>], tol: f64) {
        assert_eq!(actual.len(), expected.len(), "row count");
        for (i, (row, exp_row)) in actual.iter().zip(expected).enumerate() {
            assert_eq!(row.len(), exp_row.len(), "column count in row {i}");
            for (j, (&a, &e)) in row.iter().zip(exp_row.iter()).enumerate() {
                assert!((a - e).abs() < tol, "B[{i},{j}] = {a} but expected {e}");
            }
        }
    }

    #[test]
    fn lbfgs_seeds_identity_with_no_pairs() {
        let lb = LBfgs::new(3, 5);
        let b = dense(&lb.as_triplet());
        assert_close(
            &b,
            &[
                vec![1.0, 0.0, 0.0],
                vec![0.0, 1.0, 0.0],
                vec![0.0, 0.0, 1.0],
            ],
            1e-15,
        );
    }

    #[test]
    fn lbfgs_first_update_only_records_point() {
        let mut lb = LBfgs::new(2, 3);
        lb.update(&[0.0, 0.0], &[1.0, 1.0]);
        assert!(lb.has_prev());
        // A single point cannot form an (s, y) pair.
        assert!(lb.pairs.is_empty());
        let b = dense(&lb.as_triplet());
        assert_close(&b, &[vec![1.0, 0.0], vec![0.0, 1.0]], 1e-15);
    }

    #[test]
    fn lbfgs_quadratic_recovers_exact_hessian_at_convergence() {
        // f(x) = x0² + 2·x1², so ∇f = (2·x0, 4·x1) and the Hessian is diag(2, 4).
        let mut lb = LBfgs::new(2, 5);
        lb.update(&[0.0, 0.0], &[0.0, 0.0]);
        lb.update(&[1.0, 0.0], &[2.0, 0.0]);
        lb.update(&[1.0, 1.0], &[2.0, 4.0]);
        let b = dense(&lb.as_triplet());
        assert_close(&b, &[vec![2.0, 0.0], vec![0.0, 4.0]], 1e-9);
    }

    #[test]
    fn lbfgs_zero_step_is_not_recorded() {
        let mut lb = LBfgs::new(2, 3);
        lb.update(&[1.0, 2.0], &[0.0, 0.0]);
        lb.update(&[1.0, 2.0], &[3.0, 4.0]);
        assert!(lb.pairs.is_empty());
        // The gradient at the repeated point still becomes the new reference.
        lb.update(&[2.0, 2.0], &[5.0, 4.0]);
        assert_eq!(lb.pairs.len(), 1);
        assert_eq!(lb.pairs[0], (vec![1.0, 0.0], vec![2.0, 0.0]));
    }

    #[test]
    fn lbfgs_powell_damping_keeps_negative_curvature_pair_positive() {
        // sᵀy = -1 < 0: without damping the update would destroy positive definiteness.
        let mut lb = LBfgs::new(2, 3);
        lb.update(&[0.0, 0.0], &[0.0, 0.0]);
        lb.update(&[1.0, 0.0], &[-1.0, 0.0]);
        let b = dense(&lb.as_triplet());
        assert_close(&b, &[vec![0.2, 0.0], vec![0.0, 1.0]], 1e-12);
    }

    #[test]
    fn lbfgs_history_cap_drops_oldest() {
        let mut lb = LBfgs::new(2, 2);
        lb.update(&[0.0, 0.0], &[0.0, 0.0]);
        lb.update(&[1.0, 0.0], &[1.0, 0.0]); // pair A: s = (1, 0), y = (1, 0)
        lb.update(&[2.0, 0.0], &[3.0, 0.0]); // pair B: s = (1, 0), y = (2, 0)
        lb.update(&[2.0, 1.0], &[3.0, 2.0]); // pair C: s = (0, 1), y = (0, 2); evicts A
        assert_eq!(lb.pairs.len(), 2);
        assert_eq!(lb.pairs[0], (vec![1.0, 0.0], vec![2.0, 0.0]));
        assert_eq!(lb.pairs[1], (vec![0.0, 1.0], vec![0.0, 2.0]));
    }
}