1use crate::hess::r#trait::HessianUpdater;
34use crate::ipopt_cq::IpoptCqHandle;
35use crate::ipopt_data::IpoptDataHandle;
36use pounce_common::types::{Index, Number};
37use pounce_linalg::compound_vector::CompoundVector;
38use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
39use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
40use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
41use pounce_linalg::Vector;
42use std::rc::Rc;
43
/// One secant pair kept in the limited-memory history, together with the
/// scalar quantities that `ingest_pair` computes once when the pair is stored.
#[derive(Debug, Clone)]
pub struct CurvaturePair {
    /// The `s` vector of the pair (built in `update_hessian` as the difference
    /// between the current and the previously recorded `x`).
    pub s: Rc<dyn Vector>,
    /// The `y` vector of the pair (built in `update_hessian` from gradient and
    /// Jacobian-transpose differences).
    pub y: Rc<dyn Vector>,
    /// Cached inner product of `s` and `y` (`s.dot(y)`).
    pub s_dot_y: Number,
    /// Cached norm of `s` as returned by `nrm2`.
    pub s_norm: Number,
    /// Cached norm of `y` as returned by `nrm2`.
    pub y_norm: Number,
}
54
/// Selects which quasi-Newton formula the updater applies.
///
/// The choice affects both the acceptance test used when a pair is ingested
/// and the way `build_low_rank` turns each stored pair into factor columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateType {
    /// Pairs are screened with `bfgs_curvature_pair_ok`; every usable pair adds
    /// one column to each of the two low-rank factors, with Powell damping.
    Bfgs,
    /// Pairs are screened with `sr1_denominator_ok`; every usable pair adds a
    /// single column to one factor, chosen by the sign of the denominator.
    Sr1,
}
60
/// Rule used by `initial_hessian_scalar` to pick the scalar that seeds the
/// diagonal part of the approximation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialApprox {
    /// Always use `1.0`.
    Identity,
    /// Use `s·y / s·s`, falling back to `1.0` when `s·s` is not positive.
    Scalar1,
    /// Use `y·y / s·y`, falling back to `1.0` when `s·y` is not positive.
    Scalar2,
}
67
/// Limited-memory quasi-Newton Hessian updater.
///
/// Holds the configuration, the retained secant pairs, and the quantities
/// recorded by the previous `update_hessian` call that are needed to form the
/// next pair.
pub struct LimMemQuasiNewtonUpdater {
    /// Which update formula (and matching pair-acceptance test) to use.
    pub update_type: UpdateType,
    /// Rule for the scalar that seeds the diagonal of the approximation.
    pub initial_approx: InitialApprox,
    /// Maximum number of pairs kept in `history`; values below zero are
    /// treated as zero when the history is trimmed.
    pub max_history: i32,
    /// Upper bound handed to `initial_hessian_scalar`.
    pub init_val_max: Number,
    /// Lower bound handed to `initial_hessian_scalar`.
    pub init_val_min: Number,
    /// Accepted pairs, oldest first (new pairs are pushed at the end and the
    /// front is discarded when the cap is exceeded).
    pub history: Vec<CurvaturePair>,
    /// `x` recorded at the end of the previous `update_hessian` call, if any.
    pub last_x: Option<Rc<dyn Vector>>,
    /// Objective gradient recorded at the end of the previous call, if any.
    pub last_grad_f: Option<Rc<dyn Vector>>,
    /// `jac_c` recorded at the end of the previous call, if any.
    pub last_jac_c: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
    /// `jac_d` recorded at the end of the previous call, if any.
    pub last_jac_d: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
}
95
96impl Default for LimMemQuasiNewtonUpdater {
97 fn default() -> Self {
98 Self {
99 update_type: UpdateType::Bfgs,
100 initial_approx: InitialApprox::Scalar2,
101 max_history: 6,
102 init_val_max: 1e8,
103 init_val_min: 1e-8,
104 history: Vec::new(),
105 last_x: None,
106 last_grad_f: None,
107 last_jac_c: None,
108 last_jac_d: None,
109 }
110 }
111}
112
113impl LimMemQuasiNewtonUpdater {
114 pub fn new() -> Self {
115 Self::default()
116 }
117
118 pub fn ingest_pair(&mut self, s: Rc<dyn Vector>, y: Rc<dyn Vector>) -> bool {
128 let s_dot_y = s.dot(&*y);
129 let s_norm = s.nrm2();
130 let y_norm = y.nrm2();
131 let accept = match self.update_type {
132 UpdateType::Bfgs => bfgs_curvature_pair_ok(s_dot_y, s_norm, y_norm),
133 UpdateType::Sr1 => {
134 sr1_denominator_ok(s_dot_y, s_norm, y_norm)
140 }
141 };
142 if !accept {
143 return false;
144 }
145 self.history.push(CurvaturePair {
146 s,
147 y,
148 s_dot_y,
149 s_norm,
150 y_norm,
151 });
152 while self.history.len() > self.max_history.max(0) as usize {
154 self.history.remove(0);
155 }
156 true
157 }
158}
159
impl HessianUpdater for LimMemQuasiNewtonUpdater {
    /// Refreshes the Hessian approximation stored in `data.w`.
    ///
    /// Steps, in order:
    /// 1. Read the current `x`, `y_c`, `y_d` from `data` and the current
    ///    objective gradient and the two Jacobians from `cq`.
    /// 2. If a previous call recorded all four "last" quantities, form a new
    ///    `(s, y)` pair and offer it to `ingest_pair`.
    /// 3. Record the current quantities for the next call.
    /// 4. Compute the seed scalar, build the low-rank columns from the history
    ///    and store the assembled matrix in `data.w`.
    ///
    /// Returns `true` on every path, including the early exit taken when
    /// `data.curr` is `None`.
    fn update_hessian(&mut self, data: &IpoptDataHandle, cq: &IpoptCqHandle) -> bool {
        // Without a current iterate there is nothing to update.
        let (curr_x, curr_y_c, curr_y_d) = match data.borrow().curr.as_ref() {
            Some(c) => (c.x.clone(), c.y_c.clone(), c.y_d.clone()),
            None => return true,
        };
        let curr_grad_f = cq.borrow().curr_grad_f();
        let curr_jac_c = cq.borrow().curr_jac_c();
        let curr_jac_d = cq.borrow().curr_jac_d();

        // A pair can only be formed once a previous call has recorded all four
        // quantities; on the first call this block is skipped.
        if let (Some(prev_x), Some(prev_grad_f), Some(prev_jac_c), Some(prev_jac_d)) = (
            self.last_x.clone(),
            self.last_grad_f.clone(),
            self.last_jac_c.clone(),
            self.last_jac_d.clone(),
        ) {
            // s = curr_x - prev_x.
            let mut s = curr_x.make_new();
            s.add_two_vectors(1.0, &*curr_x, -1.0, &*prev_x, 0.0);

            // y starts as the objective-gradient difference ...
            let mut y = curr_x.make_new();
            y.add_two_vectors(1.0, &*curr_grad_f, -1.0, &*prev_grad_f, 0.0);
            // ... and then accumulates the Jacobian-transpose terms. The
            // *current* multipliers are applied to both the current and the
            // previous Jacobians.
            // NOTE(review): presumably `trans_mult_vector(a, x, b, y)` computes
            // `y = a * Aᵀ x + b * y` — confirm against the Matrix trait.
            curr_jac_c.trans_mult_vector(1.0, &*curr_y_c, 1.0, &mut *y);
            prev_jac_c.trans_mult_vector(-1.0, &*curr_y_c, 1.0, &mut *y);
            curr_jac_d.trans_mult_vector(1.0, &*curr_y_d, 1.0, &mut *y);
            prev_jac_d.trans_mult_vector(-1.0, &*curr_y_d, 1.0, &mut *y);

            // The acceptance result is deliberately ignored: a rejected pair
            // simply leaves the history as it was.
            self.ingest_pair(Rc::from(s), Rc::from(y));
        }
        // Remember the current quantities for the next call.
        self.last_x = Some(Rc::clone(&curr_x));
        self.last_grad_f = Some(Rc::clone(&curr_grad_f));
        self.last_jac_c = Some(Rc::clone(&curr_jac_c));
        self.last_jac_d = Some(Rc::clone(&curr_jac_d));

        let n_idx = curr_x.dim();
        let nu = n_idx as usize;
        // Both update types currently use the same seed-scalar computation.
        let sigma = match self.update_type {
            UpdateType::Bfgs => self.compute_sigma_bfgs(),
            UpdateType::Sr1 => self.compute_sigma_bfgs(),
        };

        // Columns of the two factors, as flat `Vec<Number>`s of length `nu`.
        let (v_cols, u_cols) = self.build_low_rank(sigma, nu);

        let col_space = DenseVectorSpace::new(n_idx);
        // Diagonal part: a vector shaped like `x` with every entry set to sigma.
        let mut diag = curr_x.make_new();
        diag.set(sigma);

        let lr_space = LowRankUpdateSymMatrixSpace::new(n_idx, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(Rc::from(diag));
        // Each factor is only attached when it has at least one column
        // (`build_multi_vector` returns `None` for an empty column list).
        if let Some(mvm) = build_multi_vector(&col_space, curr_x.as_ref(), &v_cols) {
            lr.set_v(Rc::new(mvm));
        }
        if let Some(mvm) = build_multi_vector(&col_space, curr_x.as_ref(), &u_cols) {
            lr.set_u(Rc::new(mvm));
        }

        data.borrow_mut().w = Some(Rc::new(lr));
        true
    }
}
256
257impl LimMemQuasiNewtonUpdater {
258 fn compute_sigma_bfgs(&self) -> Number {
259 if self.history.is_empty() {
260 return 1.0;
261 }
262 let last = self.history.last().unwrap();
263 let s_dot_s = last.s_norm * last.s_norm;
264 let y_dot_y = last.y_norm * last.y_norm;
265 initial_hessian_scalar(
266 self.initial_approx,
267 s_dot_s,
268 last.s_dot_y,
269 y_dot_y,
270 self.init_val_min,
271 self.init_val_max,
272 )
273 }
274
275 fn build_low_rank(&self, sigma: Number, n: usize) -> (Vec<Vec<Number>>, Vec<Vec<Number>>) {
288 let mut v_cols: Vec<Vec<Number>> = Vec::new();
289 let mut u_cols: Vec<Vec<Number>> = Vec::new();
290 if n == 0 {
291 return (v_cols, u_cols);
292 }
293 for pair in &self.history {
294 let s = dense_from_vec(pair.s.as_ref(), n);
295 let y = dense_from_vec(pair.y.as_ref(), n);
296
297 let mut bs: Vec<Number> = s.iter().map(|&si| sigma * si).collect();
299 for v in &v_cols {
300 let c: Number = (0..n).map(|i| v[i] * s[i]).sum();
301 for i in 0..n {
302 bs[i] += c * v[i];
303 }
304 }
305 for u in &u_cols {
306 let c: Number = (0..n).map(|i| u[i] * s[i]).sum();
307 for i in 0..n {
308 bs[i] -= c * u[i];
309 }
310 }
311
312 match self.update_type {
313 UpdateType::Bfgs => {
314 let s_bs: Number = (0..n).map(|i| s[i] * bs[i]).sum();
315 if s_bs <= 0.0 {
316 continue;
317 }
318 let sy = pair.s_dot_y;
319 let theta = powell_damping_theta(sy, s_bs);
320 let sr = theta * sy + (1.0 - theta) * s_bs;
321 if sr <= 0.0 {
322 continue;
323 }
324 let r_scale = 1.0 / sr.sqrt();
325 let bs_scale = 1.0 / s_bs.sqrt();
326 v_cols.push(
328 (0..n)
329 .map(|i| (theta * y[i] + (1.0 - theta) * bs[i]) * r_scale)
330 .collect(),
331 );
332 u_cols.push(bs.iter().map(|&bi| bi * bs_scale).collect());
334 }
335 UpdateType::Sr1 => {
336 let yms: Vec<Number> = (0..n).map(|i| y[i] - bs[i]).collect();
337 let denom: Number = (0..n).map(|i| yms[i] * s[i]).sum();
338 let yms_norm: Number = yms.iter().map(|&w| w * w).sum::<Number>().sqrt();
339 if !sr1_denominator_ok(denom, pair.s_norm, yms_norm) {
340 continue;
341 }
342 let scale = 1.0 / denom.abs().sqrt();
343 let col: Vec<Number> = yms.iter().map(|&w| w * scale).collect();
344 if denom > 0.0 {
345 v_cols.push(col);
346 } else {
347 u_cols.push(col);
348 }
349 }
350 }
351 }
352 (v_cols, u_cols)
353 }
354}
355
356fn build_multi_vector(
363 col_space: &Rc<DenseVectorSpace>,
364 template: &dyn Vector,
365 cols: &[Vec<Number>],
366) -> Option<MultiVectorMatrix> {
367 if cols.is_empty() {
368 return None;
369 }
370 let space = MultiVectorMatrixSpace::new(cols.len() as Index, Rc::clone(col_space));
371 let mut mvm = space.make_new_multi_vector();
372 for (k, col) in cols.iter().enumerate() {
373 let mut cv = template.make_new();
374 set_expanded(cv.as_mut(), col);
375 mvm.set_vector(k as Index, Rc::from(cv));
376 }
377 Some(mvm)
378}
379
380fn expanded_of(v: &dyn Vector) -> Vec<Number> {
384 if let Some(dv) = v.as_any().downcast_ref::<DenseVector>() {
385 return dv.expanded_values();
386 }
387 if let Some(cv) = v.as_any().downcast_ref::<CompoundVector>() {
388 let mut out = Vec::with_capacity(cv.dim() as usize);
389 for i in 0..cv.n_comps() {
390 out.extend(expanded_of(cv.comp(i)));
391 }
392 return out;
393 }
394 panic!("LimMemQuasiNewtonUpdater: unsupported primal vector type for expansion");
395}
396
397fn set_expanded(dst: &mut dyn Vector, flat: &[Number]) {
400 if let Some(dv) = dst.as_any_mut().downcast_mut::<DenseVector>() {
401 dv.set_values(flat);
402 return;
403 }
404 if let Some(cv) = dst.as_any_mut().downcast_mut::<CompoundVector>() {
405 let n = cv.n_comps();
406 let dims: Vec<usize> = (0..n).map(|i| cv.comp(i).dim() as usize).collect();
407 let mut off = 0usize;
408 for (i, &d) in dims.iter().enumerate() {
409 set_expanded(cv.comp_mut(i as Index), &flat[off..off + d]);
410 off += d;
411 }
412 return;
413 }
414 panic!("LimMemQuasiNewtonUpdater: unsupported primal vector type for set_expanded");
415}
416
417fn dense_from_vec(v: &dyn Vector, n: usize) -> Vec<Number> {
418 let ev = expanded_of(v);
419 debug_assert_eq!(ev.len(), n);
420 ev
421}
422
423pub fn initial_hessian_scalar(
435 init: InitialApprox,
436 s_dot_s: Number,
437 s_dot_y: Number,
438 y_dot_y: Number,
439 min_val: Number,
440 max_val: Number,
441) -> Number {
442 let raw = match init {
443 InitialApprox::Identity => 1.0,
444 InitialApprox::Scalar1 => {
445 if s_dot_s > 0.0 {
446 s_dot_y / s_dot_s
447 } else {
448 1.0
449 }
450 }
451 InitialApprox::Scalar2 => {
452 if s_dot_y > 0.0 {
453 y_dot_y / s_dot_y
454 } else {
455 1.0
456 }
457 }
458 };
459 raw.clamp(min_val, max_val)
460}
461
462pub fn powell_damping_theta(s_dot_y: Number, s_dot_b_s: Number) -> Number {
474 if s_dot_y >= 0.2 * s_dot_b_s {
475 1.0
476 } else {
477 let denom = s_dot_b_s - s_dot_y;
478 if denom > 0.0 {
479 0.8 * s_dot_b_s / denom
480 } else {
481 1.0
482 }
483 }
484}
485
486pub fn bfgs_curvature_pair_ok(s_dot_y: Number, s_norm: Number, y_norm: Number) -> bool {
490 let eps = 1e-8_f64;
491 s_dot_y > eps * s_norm * y_norm
492}
493
494pub fn sr1_denominator_ok(yms_dot_s: Number, s_norm: Number, yms_norm: Number) -> bool {
498 let eps = 1e-8_f64;
499 yms_dot_s.abs() > eps * s_norm * yms_norm
500}
501
// Unit tests for the scalar helpers, the pair-ingestion logic and the
// low-rank column construction.
#[cfg(test)]
mod tests {
    use super::*;

    // `Identity` ignores the inner products and yields exactly 1.0.
    #[test]
    fn identity_init_returns_one() {
        assert_eq!(
            initial_hessian_scalar(InitialApprox::Identity, 1.0, 1.0, 1.0, 1e-8, 1e8),
            1.0
        );
    }

    // `Scalar1` is s·y / s·s = 2 / 4.
    #[test]
    fn scalar1_init_is_sy_over_ss() {
        let v = initial_hessian_scalar(InitialApprox::Scalar1, 4.0, 2.0, 0.0, 1e-8, 1e8);
        assert!((v - 0.5).abs() < 1e-15);
    }

    // `Scalar2` is y·y / s·y = 8 / 2.
    #[test]
    fn scalar2_init_is_yy_over_sy() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 2.0, 8.0, 1e-8, 1e8);
        assert!((v - 4.0).abs() < 1e-15);
    }

    // 1 / 1e-20 exceeds the upper bound and is limited to it.
    #[test]
    fn init_clamped_to_max() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e-20, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1e8);
    }

    // 1 / 1e20 is below the lower bound and is limited to it.
    #[test]
    fn init_clamped_to_min() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e20, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1e-8);
    }

    // s·y >= 0.2 * sᵀBs, so no damping is applied.
    #[test]
    fn powell_no_damping_when_curvature_ok() {
        assert_eq!(powell_damping_theta(1.0, 1.0), 1.0);
    }

    // s·y < 0.2 * sᵀBs: theta = 0.8 * 1 / (1 - 0.1) = 8/9.
    #[test]
    fn powell_damps_when_curvature_violated() {
        let theta = powell_damping_theta(0.1, 1.0);
        assert!((theta - 8.0 / 9.0).abs() < 1e-15);
    }

    // The BFGS test accepts clear positive curvature and rejects a value
    // below the 1e-8 relative threshold.
    #[test]
    fn bfgs_skip_criterion() {
        assert!(bfgs_curvature_pair_ok(1.0, 1.0, 1.0));
        assert!(!bfgs_curvature_pair_ok(1e-10, 1.0, 1.0));
    }

    // The SR1 test looks at the magnitude only, so a negative value of
    // sufficient size passes while a tiny one fails.
    #[test]
    fn sr1_skip_criterion_uses_absolute_value() {
        assert!(sr1_denominator_ok(-1.0, 1.0, 1.0));
        assert!(!sr1_denominator_ok(1e-10, 1.0, 1.0));
    }

    // Helper: wraps a slice of values in a dense vector behind `Rc<dyn Vector>`.
    fn rcv(values: &[Number]) -> Rc<dyn Vector> {
        let mut v = pounce_linalg::dense_vector::DenseVectorSpace::new(values.len() as i32)
            .make_new_dense();
        v.set(0.0);
        v.values_mut().copy_from_slice(values);
        Rc::new(v)
    }

    // A pair with s·y = 1 is accepted and its cached scalars are stored.
    #[test]
    fn ingest_pair_accepts_well_curved_pair() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        let accepted = updater.ingest_pair(rcv(&[1.0, 0.0]), rcv(&[1.0, 0.0]));
        assert!(accepted);
        assert_eq!(updater.history.len(), 1);
        let pair = &updater.history[0];
        assert!((pair.s_dot_y - 1.0).abs() < 1e-15);
        assert!((pair.s_norm - 1.0).abs() < 1e-15);
        assert!((pair.y_norm - 1.0).abs() < 1e-15);
    }

    // With y = 0 the curvature is zero, so the default (BFGS) test rejects
    // the pair and the history stays empty.
    #[test]
    fn ingest_pair_skips_zero_curvature() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        let accepted = updater.ingest_pair(rcv(&[1.0]), rcv(&[0.0]));
        assert!(!accepted);
        assert!(updater.history.is_empty());
    }

    // Five accepted pairs with a cap of two leave exactly two in the history.
    #[test]
    fn history_caps_at_max_history() {
        let mut updater = LimMemQuasiNewtonUpdater {
            max_history: 2,
            ..LimMemQuasiNewtonUpdater::default()
        };
        for _ in 0..5 {
            updater.ingest_pair(rcv(&[1.0]), rcv(&[1.0]));
        }
        assert_eq!(updater.history.len(), 2);
    }

    // s·y = -1 would fail the BFGS test; acceptance shows the SR1 test
    // (absolute value) is the one being applied.
    #[test]
    fn sr1_path_routes_through_sr1_skip() {
        let mut updater = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        assert!(updater.ingest_pair(rcv(&[1.0]), rcv(&[-1.0])));
    }

    // Helper: builds a `CurvaturePair` directly, bypassing the acceptance test.
    fn pair(s: &[Number], y: &[Number]) -> CurvaturePair {
        let s_rc = rcv(s);
        let y_rc = rcv(y);
        let s_dot_y = s_rc.dot(&*y_rc);
        let s_norm = s_rc.nrm2();
        let y_norm = y_rc.nrm2();
        CurvaturePair {
            s: s_rc,
            y: y_rc,
            s_dot_y,
            s_norm,
            y_norm,
        }
    }

    // Helper: assembles the dense row-major n×n matrix
    // sigma * I + Σ v vᵀ - Σ u uᵀ from the returned columns.
    fn reconstruct_b(n: usize, sigma: Number, v: &[Vec<Number>], u: &[Vec<Number>]) -> Vec<Number> {
        let mut b = vec![0.0_f64; n * n];
        for i in 0..n {
            b[i * n + i] = sigma;
        }
        for col in v {
            for i in 0..n {
                for j in 0..n {
                    b[i * n + j] += col[i] * col[j];
                }
            }
        }
        for col in u {
            for i in 0..n {
                for j in 0..n {
                    b[i * n + j] -= col[i] * col[j];
                }
            }
        }
        b
    }

    // Helper: dense row-major matrix-vector product.
    fn mat_vec(b: &[Number], n: usize, x: &[Number]) -> Vec<Number> {
        (0..n)
            .map(|i| (0..n).map(|j| b[i * n + j] * x[j]).sum())
            .collect()
    }

    // After one undamped BFGS pair the reconstructed matrix maps s to y.
    #[test]
    fn bfgs_low_rank_recovers_hessian_action() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 1.0], &[2.0, 5.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 1.0]);
        assert!((bs[0] - 2.0).abs() < 1e-12, "Bs[0]={}", bs[0]);
        assert!((bs[1] - 5.0).abs() < 1e-12, "Bs[1]={}", bs[1]);
    }

    // With two pairs the two off-diagonal entries of the 2×2 matrix agree.
    #[test]
    fn bfgs_low_rank_keeps_symmetry() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 0.5], &[2.0, 1.0]));
        up.history.push(pair(&[0.7, 1.2], &[1.0, 2.5]));
        let (v, u) = up.build_low_rank(3.0, 2);
        let b = reconstruct_b(2, 3.0, &v, &u);
        assert!((b[1] - b[2]).abs() < 1e-12);
    }

    // For SR1, y - Bs = [1, 4] and its product with s is 5 > 0, so the single
    // column lands in V and the reconstructed matrix maps s to y.
    #[test]
    fn sr1_low_rank_recovers_hessian_action() {
        let mut up = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        up.history.push(pair(&[1.0, 1.0], &[2.0, 5.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        assert_eq!(v.len(), 1, "positive denom routes to V");
        assert!(u.is_empty());
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 1.0]);
        assert!((bs[0] - 2.0).abs() < 1e-12);
        assert!((bs[1] - 5.0).abs() < 1e-12);
    }

    // No stored pairs means no columns in either factor.
    #[test]
    fn empty_history_yields_no_columns() {
        let up = LimMemQuasiNewtonUpdater::new();
        let (v, u) = up.build_low_rank(1.0, 4);
        assert!(v.is_empty() && u.is_empty());
    }
}