1use crate::hess::r#trait::HessianUpdater;
27use crate::ipopt_cq::IpoptCqHandle;
28use crate::ipopt_data::IpoptDataHandle;
29use pounce_common::types::{Index, Number};
30use pounce_linalg::dense_vector::DenseVector;
31use pounce_linalg::triplet::{SymTMatrix, SymTMatrixSpace};
32use pounce_linalg::Vector;
33use std::rc::Rc;
34
/// One curvature pair `(s, y)` kept in the quasi-Newton history, together with
/// scalar quantities cached when the pair was accepted by `ingest_pair`.
#[derive(Debug, Clone)]
pub struct CurvaturePair {
    /// Step vector; `update_hessian` forms it as the current `x` minus the previous `x`.
    pub s: Rc<dyn Vector>,
    /// Vector paired with `s`; `update_hessian` forms it from the change in the
    /// objective gradient plus the change in the Jacobian-transpose products
    /// with the current constraint multipliers.
    pub y: Rc<dyn Vector>,
    /// Cached value of `s.dot(y)`.
    pub s_dot_y: Number,
    /// Cached value of `s.nrm2()`.
    pub s_norm: Number,
    /// Cached value of `y.nrm2()`.
    pub y_norm: Number,
}
45
/// Selects which quasi-Newton formula the updater applies.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateType {
    /// Candidate pairs are screened with `bfgs_curvature_pair_ok` and replayed
    /// through `apply_bfgs_history`.
    Bfgs,
    /// Candidate pairs are screened with `sr1_denominator_ok` and replayed
    /// through `apply_sr1_history`.
    Sr1,
}
51
/// Rule used by `initial_hessian_scalar` to pick the scalar placed on the
/// diagonal of the initial matrix (before clamping).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialApprox {
    /// Always use 1.0.
    Identity,
    /// Use `s·y / s·s` when `s·s > 0`, otherwise 1.0.
    Scalar1,
    /// Use `y·y / s·y` when `s·y > 0`, otherwise 1.0.
    Scalar2,
}
58
/// Quasi-Newton Hessian updater that keeps a bounded history of curvature
/// pairs and rebuilds the approximation from that history on every
/// `update_hessian` call.
pub struct LimMemQuasiNewtonUpdater {
    /// Which update formula (BFGS or SR1) is applied to the history.
    pub update_type: UpdateType,
    /// Rule forwarded to `initial_hessian_scalar` for the initial diagonal value.
    pub initial_approx: InitialApprox,
    /// Maximum number of pairs `ingest_pair` retains; values below zero are
    /// treated as zero there.
    pub max_history: i32,
    /// Upper clamp bound for the initial diagonal scalar.
    pub init_val_max: Number,
    /// Lower clamp bound for the initial diagonal scalar.
    pub init_val_min: Number,
    /// Accepted pairs, oldest first (new pairs are pushed at the back and the
    /// front is dropped when the cap is exceeded).
    pub history: Vec<CurvaturePair>,
    /// `x` recorded at the end of the previous `update_hessian` call, if any.
    pub last_x: Option<Rc<dyn Vector>>,
    /// Objective gradient recorded at the end of the previous `update_hessian` call.
    pub last_grad_f: Option<Rc<dyn Vector>>,
    /// Jacobian `jac_c` recorded at the end of the previous `update_hessian` call.
    pub last_jac_c: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
    /// Jacobian `jac_d` recorded at the end of the previous `update_hessian` call.
    pub last_jac_d: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
    /// Sparsity structure of the matrix written to `data.w`; filled in lazily
    /// by the first `update_hessian` call that reaches that step.
    pub h_space: Option<Rc<SymTMatrixSpace>>,
}
95
impl Default for LimMemQuasiNewtonUpdater {
    /// Builds an updater with BFGS updates, `Scalar2` initial scaling, a
    /// history cap of 6 pairs, clamp bounds `[1e-8, 1e8]`, and no recorded
    /// iterate, history, or sparsity structure.
    fn default() -> Self {
        Self {
            update_type: UpdateType::Bfgs,
            initial_approx: InitialApprox::Scalar2,
            // At most this many pairs survive the trimming in `ingest_pair`.
            max_history: 6,
            // Clamp bounds handed to `initial_hessian_scalar`.
            init_val_max: 1e8,
            init_val_min: 1e-8,
            history: Vec::new(),
            // Nothing has been observed yet, so the first `update_hessian`
            // call cannot form a pair.
            last_x: None,
            last_grad_f: None,
            last_jac_c: None,
            last_jac_d: None,
            // Determined on first use inside `update_hessian`.
            h_space: None,
        }
    }
}
113
114impl LimMemQuasiNewtonUpdater {
115 pub fn new() -> Self {
116 Self::default()
117 }
118
119 pub fn ingest_pair(&mut self, s: Rc<dyn Vector>, y: Rc<dyn Vector>) -> bool {
129 let s_dot_y = s.dot(&*y);
130 let s_norm = s.nrm2();
131 let y_norm = y.nrm2();
132 let accept = match self.update_type {
133 UpdateType::Bfgs => bfgs_curvature_pair_ok(s_dot_y, s_norm, y_norm),
134 UpdateType::Sr1 => {
135 sr1_denominator_ok(s_dot_y, s_norm, y_norm)
141 }
142 };
143 if !accept {
144 return false;
145 }
146 self.history.push(CurvaturePair {
147 s,
148 y,
149 s_dot_y,
150 s_norm,
151 y_norm,
152 });
153 while self.history.len() > self.max_history.max(0) as usize {
155 self.history.remove(0);
156 }
157 true
158 }
159}
160
161impl HessianUpdater for LimMemQuasiNewtonUpdater {
162 fn update_hessian(&mut self, data: &IpoptDataHandle, cq: &IpoptCqHandle) -> bool {
170 let (curr_x, curr_y_c, curr_y_d) = match data.borrow().curr.as_ref() {
171 Some(c) => (c.x.clone(), c.y_c.clone(), c.y_d.clone()),
172 None => return true,
173 };
174 let curr_grad_f = cq.borrow().curr_grad_f();
175 let curr_jac_c = cq.borrow().curr_jac_c();
176 let curr_jac_d = cq.borrow().curr_jac_d();
177
178 if self.h_space.is_none() {
183 let h = cq.borrow().curr_exact_hessian();
184 if let Some(symt) = h.as_any().downcast_ref::<SymTMatrix>() {
185 self.h_space = Some(Rc::clone(symt.space()));
186 } else {
187 self.h_space = Some(make_full_lower_triangle_space(curr_x.dim()));
188 }
189 }
190
191 if let (Some(prev_x), Some(prev_grad_f), Some(prev_jac_c), Some(prev_jac_d)) = (
201 self.last_x.clone(),
202 self.last_grad_f.clone(),
203 self.last_jac_c.clone(),
204 self.last_jac_d.clone(),
205 ) {
206 let mut s = curr_x.make_new();
207 s.add_two_vectors(1.0, &*curr_x, -1.0, &*prev_x, 0.0);
208
209 let mut y = curr_x.make_new();
210 y.add_two_vectors(1.0, &*curr_grad_f, -1.0, &*prev_grad_f, 0.0);
212 curr_jac_c.trans_mult_vector(1.0, &*curr_y_c, 1.0, &mut *y);
214 prev_jac_c.trans_mult_vector(-1.0, &*curr_y_c, 1.0, &mut *y);
215 curr_jac_d.trans_mult_vector(1.0, &*curr_y_d, 1.0, &mut *y);
217 prev_jac_d.trans_mult_vector(-1.0, &*curr_y_d, 1.0, &mut *y);
218
219 self.ingest_pair(Rc::from(s), Rc::from(y));
220 }
221 self.last_x = Some(Rc::clone(&curr_x));
222 self.last_grad_f = Some(Rc::clone(&curr_grad_f));
223 self.last_jac_c = Some(Rc::clone(&curr_jac_c));
224 self.last_jac_d = Some(Rc::clone(&curr_jac_d));
225
226 let n = curr_x.dim() as usize;
227 let sigma = match self.update_type {
228 UpdateType::Bfgs => self.compute_sigma_bfgs(),
229 UpdateType::Sr1 => self.compute_sigma_bfgs(),
233 };
234 let mut dense = vec![0.0_f64; n * n];
235 for i in 0..n {
236 dense[i * n + i] = sigma;
237 }
238 match self.update_type {
239 UpdateType::Bfgs => apply_bfgs_history(&mut dense, n, &self.history),
240 UpdateType::Sr1 => apply_sr1_history(&mut dense, n, &self.history),
241 }
242
243 let space = Rc::clone(self.h_space.as_ref().unwrap());
244 let mut mat = SymTMatrix::new(Rc::clone(&space));
245 let irows = space.irows();
246 let jcols = space.jcols();
247 let mut vals = vec![0.0_f64; irows.len()];
248 for k in 0..irows.len() {
249 let i = (irows[k] - 1) as usize;
251 let j = (jcols[k] - 1) as usize;
252 vals[k] = dense[i * n + j];
253 }
254 mat.set_values(&vals);
255 data.borrow_mut().w = Some(Rc::new(mat));
256 true
257 }
258}
259
260impl LimMemQuasiNewtonUpdater {
261 fn compute_sigma_bfgs(&self) -> Number {
262 if self.history.is_empty() {
263 return 1.0;
264 }
265 let last = self.history.last().unwrap();
266 let s_dot_s = last.s_norm * last.s_norm;
267 let y_dot_y = last.y_norm * last.y_norm;
268 initial_hessian_scalar(
269 self.initial_approx,
270 s_dot_s,
271 last.s_dot_y,
272 y_dot_y,
273 self.init_val_min,
274 self.init_val_max,
275 )
276 }
277}
278
279fn make_full_lower_triangle_space(n: Index) -> Rc<SymTMatrixSpace> {
284 let nz = (n as usize) * ((n as usize) + 1) / 2;
285 let mut irows: Vec<Index> = Vec::with_capacity(nz);
286 let mut jcols: Vec<Index> = Vec::with_capacity(nz);
287 for i in 1..=n {
288 for j in 1..=i {
289 irows.push(i);
290 jcols.push(j);
291 }
292 }
293 SymTMatrixSpace::new(n, irows, jcols)
294}
295
296fn dense_from_vec(v: &dyn Vector, n: usize) -> Vec<Number> {
297 if let Some(dv) = v.as_any().downcast_ref::<DenseVector>() {
298 let ev = dv.expanded_values();
299 debug_assert_eq!(ev.len(), n);
300 return ev;
301 }
302 panic!("LimMemQuasiNewtonUpdater: curvature pairs must be DenseVector-backed");
303}
304
305fn apply_bfgs_history(b: &mut [Number], n: usize, history: &[CurvaturePair]) {
313 if n == 0 {
314 return;
315 }
316 let mut bs = vec![0.0_f64; n];
317 let mut r = vec![0.0_f64; n];
318 for pair in history {
319 let s = dense_from_vec(pair.s.as_ref(), n);
320 let y = dense_from_vec(pair.y.as_ref(), n);
321 for i in 0..n {
323 let row = &b[i * n..(i + 1) * n];
324 let mut acc = 0.0;
325 for j in 0..n {
326 acc += row[j] * s[j];
327 }
328 bs[i] = acc;
329 }
330 let s_bs: Number = (0..n).map(|i| s[i] * bs[i]).sum();
331 if s_bs <= 0.0 {
332 continue;
333 }
334 let sy = pair.s_dot_y;
335 let theta = powell_damping_theta(sy, s_bs);
336 for i in 0..n {
337 r[i] = theta * y[i] + (1.0 - theta) * bs[i];
338 }
339 let sr: Number = theta * sy + (1.0 - theta) * s_bs;
340 if sr <= 0.0 {
341 continue;
342 }
343 for i in 0..n {
344 let r_i = r[i];
345 let bs_i = bs[i];
346 let row = &mut b[i * n..(i + 1) * n];
347 for j in 0..n {
348 row[j] += r_i * r[j] / sr - bs_i * bs[j] / s_bs;
349 }
350 }
351 }
352}
353
354fn apply_sr1_history(b: &mut [Number], n: usize, history: &[CurvaturePair]) {
360 if n == 0 {
361 return;
362 }
363 let mut bs = vec![0.0_f64; n];
364 let mut yms = vec![0.0_f64; n];
365 for pair in history {
366 let s = dense_from_vec(pair.s.as_ref(), n);
367 let y = dense_from_vec(pair.y.as_ref(), n);
368 for i in 0..n {
369 let row = &b[i * n..(i + 1) * n];
370 let mut acc = 0.0;
371 for j in 0..n {
372 acc += row[j] * s[j];
373 }
374 bs[i] = acc;
375 }
376 for i in 0..n {
377 yms[i] = y[i] - bs[i];
378 }
379 let denom: Number = (0..n).map(|i| yms[i] * s[i]).sum();
380 let yms_norm: Number = (0..n).map(|i| yms[i] * yms[i]).sum::<Number>().sqrt();
381 if !sr1_denominator_ok(denom, pair.s_norm, yms_norm) {
382 continue;
383 }
384 for i in 0..n {
385 let yms_i = yms[i];
386 let row = &mut b[i * n..(i + 1) * n];
387 for j in 0..n {
388 row[j] += yms_i * yms[j] / denom;
389 }
390 }
391 }
392}
393
394pub fn initial_hessian_scalar(
406 init: InitialApprox,
407 s_dot_s: Number,
408 s_dot_y: Number,
409 y_dot_y: Number,
410 min_val: Number,
411 max_val: Number,
412) -> Number {
413 let raw = match init {
414 InitialApprox::Identity => 1.0,
415 InitialApprox::Scalar1 => {
416 if s_dot_s > 0.0 {
417 s_dot_y / s_dot_s
418 } else {
419 1.0
420 }
421 }
422 InitialApprox::Scalar2 => {
423 if s_dot_y > 0.0 {
424 y_dot_y / s_dot_y
425 } else {
426 1.0
427 }
428 }
429 };
430 raw.clamp(min_val, max_val)
431}
432
433pub fn powell_damping_theta(s_dot_y: Number, s_dot_b_s: Number) -> Number {
445 if s_dot_y >= 0.2 * s_dot_b_s {
446 1.0
447 } else {
448 let denom = s_dot_b_s - s_dot_y;
449 if denom > 0.0 {
450 0.8 * s_dot_b_s / denom
451 } else {
452 1.0
453 }
454 }
455}
456
457pub fn bfgs_curvature_pair_ok(s_dot_y: Number, s_norm: Number, y_norm: Number) -> bool {
461 let eps = 1e-8_f64;
462 s_dot_y > eps * s_norm * y_norm
463}
464
465pub fn sr1_denominator_ok(yms_dot_s: Number, s_norm: Number, yms_norm: Number) -> bool {
469 let eps = 1e-8_f64;
470 yms_dot_s.abs() > eps * s_norm * yms_norm
471}
472
#[cfg(test)]
mod tests {
    use super::*;

    /// Clamp bounds shared by the `initial_hessian_scalar` cases.
    const MIN_VAL: Number = 1e-8;
    const MAX_VAL: Number = 1e8;

    // ---- initial_hessian_scalar -------------------------------------------

    #[test]
    fn identity_init_returns_one() {
        let sigma =
            initial_hessian_scalar(InitialApprox::Identity, 1.0, 1.0, 1.0, MIN_VAL, MAX_VAL);
        assert_eq!(sigma, 1.0);
    }

    #[test]
    fn scalar1_init_is_sy_over_ss() {
        // s.y / s.s = 2 / 4
        let sigma =
            initial_hessian_scalar(InitialApprox::Scalar1, 4.0, 2.0, 0.0, MIN_VAL, MAX_VAL);
        assert!((sigma - 0.5).abs() < 1e-15);
    }

    #[test]
    fn scalar2_init_is_yy_over_sy() {
        // y.y / s.y = 8 / 2
        let sigma =
            initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 2.0, 8.0, MIN_VAL, MAX_VAL);
        assert!((sigma - 4.0).abs() < 1e-15);
    }

    #[test]
    fn init_clamped_to_max() {
        // 1 / 1e-20 is far above the upper bound.
        let sigma =
            initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e-20, 1.0, MIN_VAL, MAX_VAL);
        assert_eq!(sigma, 1e8);
    }

    #[test]
    fn init_clamped_to_min() {
        // 1 / 1e20 is far below the lower bound.
        let sigma =
            initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e20, 1.0, MIN_VAL, MAX_VAL);
        assert_eq!(sigma, 1e-8);
    }

    // ---- scalar helpers ----------------------------------------------------

    #[test]
    fn powell_no_damping_when_curvature_ok() {
        assert_eq!(powell_damping_theta(1.0, 1.0), 1.0);
    }

    #[test]
    fn powell_damps_when_curvature_violated() {
        // 0.8 * 1 / (1 - 0.1)
        let expected = 8.0 / 9.0;
        let theta = powell_damping_theta(0.1, 1.0);
        assert!((theta - expected).abs() < 1e-15);
    }

    #[test]
    fn bfgs_skip_criterion() {
        assert!(bfgs_curvature_pair_ok(1.0, 1.0, 1.0));
        assert!(!bfgs_curvature_pair_ok(1e-10, 1.0, 1.0));
    }

    #[test]
    fn sr1_skip_criterion_uses_absolute_value() {
        assert!(sr1_denominator_ok(-1.0, 1.0, 1.0));
        assert!(!sr1_denominator_ok(1e-10, 1.0, 1.0));
    }

    // ---- fixtures ----------------------------------------------------------

    /// Wraps `values` in a reference-counted dense vector.
    fn rcv(values: &[Number]) -> Rc<dyn Vector> {
        let space = pounce_linalg::dense_vector::DenseVectorSpace::new(values.len() as i32);
        let mut dense = space.make_new_dense();
        dense.set(0.0);
        dense.values_mut().copy_from_slice(values);
        Rc::new(dense)
    }

    /// Builds a `CurvaturePair` with its cached scalars filled in.
    fn pair(s: &[Number], y: &[Number]) -> CurvaturePair {
        let s_rc = rcv(s);
        let y_rc = rcv(y);
        let s_dot_y = s_rc.dot(&*y_rc);
        let s_norm = s_rc.nrm2();
        let y_norm = y_rc.nrm2();
        CurvaturePair {
            s: s_rc,
            y: y_rc,
            s_dot_y,
            s_norm,
            y_norm,
        }
    }

    /// Product of a row-major 2x2 matrix with the vector `(1, 1)`.
    fn times_ones(b: &[Number]) -> (Number, Number) {
        (b[0] + b[1], b[2] + b[3])
    }

    // ---- ingest_pair -------------------------------------------------------

    #[test]
    fn ingest_pair_accepts_well_curved_pair() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        assert!(updater.ingest_pair(rcv(&[1.0, 0.0]), rcv(&[1.0, 0.0])));
        assert_eq!(updater.history.len(), 1);

        let stored = &updater.history[0];
        for cached in [stored.s_dot_y, stored.s_norm, stored.y_norm] {
            assert!((cached - 1.0).abs() < 1e-15);
        }
    }

    #[test]
    fn ingest_pair_skips_zero_curvature() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        assert!(!updater.ingest_pair(rcv(&[1.0]), rcv(&[0.0])));
        assert!(updater.history.is_empty());
    }

    #[test]
    fn history_caps_at_max_history() {
        let mut updater = LimMemQuasiNewtonUpdater {
            max_history: 2,
            ..Default::default()
        };
        (0..5).for_each(|_| {
            updater.ingest_pair(rcv(&[1.0]), rcv(&[1.0]));
        });
        assert_eq!(updater.history.len(), 2);
    }

    #[test]
    fn sr1_path_routes_through_sr1_skip() {
        // Negative s.y would fail the BFGS screen but passes the SR1 one.
        let mut updater = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..Default::default()
        };
        assert!(updater.ingest_pair(rcv(&[1.0]), rcv(&[-1.0])));
    }

    // ---- dense history replay ---------------------------------------------

    #[test]
    fn bfgs_quadratic_recovers_hessian() {
        // Starting from the identity, one update must satisfy B s = y.
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let p = pair(&[1.0, 1.0], &[2.0, 5.0]);
        apply_bfgs_history(&mut b, 2, std::slice::from_ref(&p));
        let (bs0, bs1) = times_ones(&b);
        assert!((bs0 - 2.0).abs() < 1e-12, "Bs[0]={}", bs0);
        assert!((bs1 - 5.0).abs() < 1e-12, "Bs[1]={}", bs1);
    }

    #[test]
    fn bfgs_history_keeps_symmetry() {
        let mut b = vec![3.0, 0.0, 0.0, 3.0];
        let pairs = [
            pair(&[1.0, 0.5], &[2.0, 1.0]),
            pair(&[0.7, 1.2], &[1.0, 2.5]),
        ];
        apply_bfgs_history(&mut b, 2, &pairs);
        let off_diagonal_gap = (b[1] - b[2]).abs();
        assert!(off_diagonal_gap < 1e-12);
    }

    #[test]
    fn sr1_quadratic_one_pair_recovers_hessian_action() {
        // Starting from the identity, one update must satisfy B s = y.
        let mut b = vec![1.0, 0.0, 0.0, 1.0];
        let p = pair(&[1.0, 1.0], &[2.0, 5.0]);
        apply_sr1_history(&mut b, 2, std::slice::from_ref(&p));
        let (bs0, bs1) = times_ones(&b);
        assert!((bs0 - 2.0).abs() < 1e-12);
        assert!((bs1 - 5.0).abs() < 1e-12);
    }

    // ---- sparsity structure -----------------------------------------------

    #[test]
    fn full_lower_triangle_space_has_n_n_plus_1_over_2() {
        let s = make_full_lower_triangle_space(4);
        assert_eq!(s.dim(), 4);
        assert_eq!(s.nonzeros(), 10);
        // First three entries in row-major, 1-based order.
        let leading = [(1, 1), (2, 1), (2, 2)];
        for (k, &(row, col)) in leading.iter().enumerate() {
            assert_eq!(s.irows()[k], row);
            assert_eq!(s.jcols()[k], col);
        }
    }
}