1use crate::hess::r#trait::HessianUpdater;
34use crate::ipopt_cq::IpoptCqHandle;
35use crate::ipopt_data::IpoptDataHandle;
36use pounce_common::types::{Index, Number};
37use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
38use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
39use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
40use pounce_linalg::Vector;
41use std::rc::Rc;
42
/// One stored `(s, y)` pair together with the scalar products that the
/// update formulas reuse.
///
/// The scalar fields are computed once when the pair is created so that later
/// passes over the history do not have to touch the vectors again.
#[derive(Debug, Clone)]
pub struct CurvaturePair {
    /// The `s` vector of the pair (the updater builds it as the difference of
    /// two consecutive `x` iterates).
    pub s: Rc<dyn Vector>,
    /// The `y` vector of the pair (the updater builds it as a difference of
    /// gradient-like quantities).
    pub y: Rc<dyn Vector>,
    /// Cached inner product `s.dot(y)`.
    pub s_dot_y: Number,
    /// Cached `s.nrm2()`; its square is used as `s·s` when the initial scalar
    /// is computed.
    pub s_norm: Number,
    /// Cached `y.nrm2()`; its square is used as `y·y` when the initial scalar
    /// is computed.
    pub y_norm: Number,
}
53
/// Selects which quasi-Newton formula is used to turn the stored pairs into
/// low-rank correction columns.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum UpdateType {
    /// Rank-two BFGS update; each accepted pair contributes one added and one
    /// subtracted column, with Powell damping applied to the added one.
    Bfgs,
    /// Symmetric rank-one update; each accepted pair contributes a single
    /// column that is added or subtracted depending on the sign of its
    /// denominator.
    Sr1,
}
59
/// Selects how the scalar multiplying the identity in the initial Hessian
/// approximation is chosen (see `initial_hessian_scalar`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum InitialApprox {
    /// Always use `1.0`.
    Identity,
    /// Use `s·y / s·s`, falling back to `1.0` when `s·s` is not positive.
    Scalar1,
    /// Use `y·y / s·y`, falling back to `1.0` when `s·y` is not positive.
    Scalar2,
}
66
/// Limited-memory quasi-Newton Hessian updater.
///
/// Holds the configuration, the bounded history of accepted curvature pairs,
/// and the quantities remembered from the previous `update_hessian` call that
/// are needed to form the next pair.
pub struct LimMemQuasiNewtonUpdater {
    /// Formula used to build the low-rank columns (BFGS or SR1).
    pub update_type: UpdateType,
    /// Rule used to pick the scalar of the diagonal part.
    pub initial_approx: InitialApprox,
    /// Maximum number of pairs kept in `history`; values below zero are
    /// treated as zero when the history is trimmed.
    pub max_history: i32,
    /// Upper bound applied to the diagonal scalar.
    pub init_val_max: Number,
    /// Lower bound applied to the diagonal scalar.
    pub init_val_min: Number,
    /// Accepted curvature pairs, oldest first (new pairs are pushed at the
    /// end and the front is dropped on overflow).
    pub history: Vec<CurvaturePair>,
    /// `x` seen by the previous `update_hessian` call, if any.
    pub last_x: Option<Rc<dyn Vector>>,
    /// Objective gradient seen by the previous `update_hessian` call, if any.
    pub last_grad_f: Option<Rc<dyn Vector>>,
    /// `jac_c` matrix seen by the previous `update_hessian` call, if any.
    pub last_jac_c: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
    /// `jac_d` matrix seen by the previous `update_hessian` call, if any.
    pub last_jac_d: Option<Rc<dyn pounce_linalg::matrix::Matrix>>,
}
94
impl Default for LimMemQuasiNewtonUpdater {
    /// Returns an updater with BFGS updates, the `Scalar2` initial scaling, a
    /// history of at most 6 pairs, scalar bounds `[1e-8, 1e8]`, an empty
    /// history and no remembered previous iterate.
    fn default() -> Self {
        Self {
            update_type: UpdateType::Bfgs,
            initial_approx: InitialApprox::Scalar2,
            max_history: 6,
            init_val_max: 1e8,
            init_val_min: 1e-8,
            // No pairs and no previous-iterate data until the first update.
            history: Vec::new(),
            last_x: None,
            last_grad_f: None,
            last_jac_c: None,
            last_jac_d: None,
        }
    }
}
111
112impl LimMemQuasiNewtonUpdater {
113 pub fn new() -> Self {
114 Self::default()
115 }
116
117 pub fn ingest_pair(&mut self, s: Rc<dyn Vector>, y: Rc<dyn Vector>) -> bool {
127 let s_dot_y = s.dot(&*y);
128 let s_norm = s.nrm2();
129 let y_norm = y.nrm2();
130 let accept = match self.update_type {
131 UpdateType::Bfgs => bfgs_curvature_pair_ok(s_dot_y, s_norm, y_norm),
132 UpdateType::Sr1 => {
133 sr1_denominator_ok(s_dot_y, s_norm, y_norm)
139 }
140 };
141 if !accept {
142 return false;
143 }
144 self.history.push(CurvaturePair {
145 s,
146 y,
147 s_dot_y,
148 s_norm,
149 y_norm,
150 });
151 while self.history.len() > self.max_history.max(0) as usize {
153 self.history.remove(0);
154 }
155 true
156 }
157}
158
impl HessianUpdater for LimMemQuasiNewtonUpdater {
    /// Refreshes the quasi-Newton Hessian approximation stored in `data.w`.
    ///
    /// Steps, in order:
    /// 1. Read `x`, `y_c`, `y_d` from `data.curr` and the objective gradient
    ///    and the two Jacobians from `cq`. If `data.curr` is `None` nothing is
    ///    done and `true` is returned.
    /// 2. If values from a previous call are remembered, form a new `(s, y)`
    ///    pair and offer it to `ingest_pair` (its accept/reject result is not
    ///    inspected).
    /// 3. Remember the current values for the next call.
    /// 4. Rebuild the low-rank representation from the whole history and store
    ///    it in `data.w`.
    ///
    /// Always returns `true`.
    fn update_hessian(&mut self, data: &IpoptDataHandle, cq: &IpoptCqHandle) -> bool {
        // NOTE(review): `curr` is presumably the current iterate — confirm
        // against the `IpoptData` definition.
        let (curr_x, curr_y_c, curr_y_d) = match data.borrow().curr.as_ref() {
            Some(c) => (c.x.clone(), c.y_c.clone(), c.y_d.clone()),
            None => return true,
        };
        let curr_grad_f = cq.borrow().curr_grad_f();
        let curr_jac_c = cq.borrow().curr_jac_c();
        let curr_jac_d = cq.borrow().curr_jac_d();

        // A pair can only be formed once all four "previous" quantities exist,
        // i.e. from the second call onwards.
        if let (Some(prev_x), Some(prev_grad_f), Some(prev_jac_c), Some(prev_jac_d)) = (
            self.last_x.clone(),
            self.last_grad_f.clone(),
            self.last_jac_c.clone(),
            self.last_jac_d.clone(),
        ) {
            // s = curr_x - prev_x.
            // NOTE(review): assumes `add_two_vectors(a, v1, b, v2, c)` computes
            // `a*v1 + b*v2 + c*self` — confirm against the `Vector` trait.
            let mut s = curr_x.make_new();
            s.add_two_vectors(1.0, &*curr_x, -1.0, &*prev_x, 0.0);

            // y starts as curr_grad_f - prev_grad_f; the four calls below then
            // accumulate the Jacobian-transpose terms, using the *current*
            // multipliers `curr_y_c` / `curr_y_d` with both the current (+)
            // and the previous (-) Jacobians.
            // NOTE(review): assumes `trans_mult_vector(alpha, x, beta, y)`
            // computes `y = alpha * Aᵀ x + beta * y` — confirm against the
            // `Matrix` trait.
            let mut y = curr_x.make_new();
            y.add_two_vectors(1.0, &*curr_grad_f, -1.0, &*prev_grad_f, 0.0);
            curr_jac_c.trans_mult_vector(1.0, &*curr_y_c, 1.0, &mut *y);
            prev_jac_c.trans_mult_vector(-1.0, &*curr_y_c, 1.0, &mut *y);
            curr_jac_d.trans_mult_vector(1.0, &*curr_y_d, 1.0, &mut *y);
            prev_jac_d.trans_mult_vector(-1.0, &*curr_y_d, 1.0, &mut *y);

            self.ingest_pair(Rc::from(s), Rc::from(y));
        }
        // Remember the current point regardless of whether a pair was formed
        // or accepted.
        self.last_x = Some(Rc::clone(&curr_x));
        self.last_grad_f = Some(Rc::clone(&curr_grad_f));
        self.last_jac_c = Some(Rc::clone(&curr_jac_c));
        self.last_jac_d = Some(Rc::clone(&curr_jac_d));

        let n_idx = curr_x.dim();
        let nu = n_idx as usize;
        // Both update types currently obtain the diagonal scalar from the same
        // routine.
        let sigma = match self.update_type {
            UpdateType::Bfgs => self.compute_sigma_bfgs(),
            UpdateType::Sr1 => self.compute_sigma_bfgs(),
        };

        // Columns of the added (`v_cols`) and subtracted (`u_cols`) parts.
        let (v_cols, u_cols) = self.build_low_rank(sigma, nu);

        // Diagonal part: a dense vector with every entry equal to `sigma`.
        let col_space = DenseVectorSpace::new(n_idx);
        let mut diag = col_space.make_new_dense();
        diag.set_values(&vec![sigma; nu]);

        let lr_space = LowRankUpdateSymMatrixSpace::new(n_idx, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(Rc::new(diag) as Rc<dyn Vector>);
        // `V` / `U` are only set when there is at least one column of that kind.
        if let Some(mvm) = build_multi_vector(&col_space, &v_cols) {
            lr.set_v(Rc::new(mvm));
        }
        if let Some(mvm) = build_multi_vector(&col_space, &u_cols) {
            lr.set_u(Rc::new(mvm));
        }

        data.borrow_mut().w = Some(Rc::new(lr));
        true
    }
}
245
246impl LimMemQuasiNewtonUpdater {
247 fn compute_sigma_bfgs(&self) -> Number {
248 if self.history.is_empty() {
249 return 1.0;
250 }
251 let last = self.history.last().unwrap();
252 let s_dot_s = last.s_norm * last.s_norm;
253 let y_dot_y = last.y_norm * last.y_norm;
254 initial_hessian_scalar(
255 self.initial_approx,
256 s_dot_s,
257 last.s_dot_y,
258 y_dot_y,
259 self.init_val_min,
260 self.init_val_max,
261 )
262 }
263
264 fn build_low_rank(&self, sigma: Number, n: usize) -> (Vec<Vec<Number>>, Vec<Vec<Number>>) {
277 let mut v_cols: Vec<Vec<Number>> = Vec::new();
278 let mut u_cols: Vec<Vec<Number>> = Vec::new();
279 if n == 0 {
280 return (v_cols, u_cols);
281 }
282 for pair in &self.history {
283 let s = dense_from_vec(pair.s.as_ref(), n);
284 let y = dense_from_vec(pair.y.as_ref(), n);
285
286 let mut bs: Vec<Number> = s.iter().map(|&si| sigma * si).collect();
288 for v in &v_cols {
289 let c: Number = (0..n).map(|i| v[i] * s[i]).sum();
290 for i in 0..n {
291 bs[i] += c * v[i];
292 }
293 }
294 for u in &u_cols {
295 let c: Number = (0..n).map(|i| u[i] * s[i]).sum();
296 for i in 0..n {
297 bs[i] -= c * u[i];
298 }
299 }
300
301 match self.update_type {
302 UpdateType::Bfgs => {
303 let s_bs: Number = (0..n).map(|i| s[i] * bs[i]).sum();
304 if s_bs <= 0.0 {
305 continue;
306 }
307 let sy = pair.s_dot_y;
308 let theta = powell_damping_theta(sy, s_bs);
309 let sr = theta * sy + (1.0 - theta) * s_bs;
310 if sr <= 0.0 {
311 continue;
312 }
313 let r_scale = 1.0 / sr.sqrt();
314 let bs_scale = 1.0 / s_bs.sqrt();
315 v_cols.push(
317 (0..n)
318 .map(|i| (theta * y[i] + (1.0 - theta) * bs[i]) * r_scale)
319 .collect(),
320 );
321 u_cols.push(bs.iter().map(|&bi| bi * bs_scale).collect());
323 }
324 UpdateType::Sr1 => {
325 let yms: Vec<Number> = (0..n).map(|i| y[i] - bs[i]).collect();
326 let denom: Number = (0..n).map(|i| yms[i] * s[i]).sum();
327 let yms_norm: Number = yms.iter().map(|&w| w * w).sum::<Number>().sqrt();
328 if !sr1_denominator_ok(denom, pair.s_norm, yms_norm) {
329 continue;
330 }
331 let scale = 1.0 / denom.abs().sqrt();
332 let col: Vec<Number> = yms.iter().map(|&w| w * scale).collect();
333 if denom > 0.0 {
334 v_cols.push(col);
335 } else {
336 u_cols.push(col);
337 }
338 }
339 }
340 }
341 (v_cols, u_cols)
342 }
343}
344
345fn build_multi_vector(
349 col_space: &Rc<DenseVectorSpace>,
350 cols: &[Vec<Number>],
351) -> Option<MultiVectorMatrix> {
352 if cols.is_empty() {
353 return None;
354 }
355 let space = MultiVectorMatrixSpace::new(cols.len() as Index, Rc::clone(col_space));
356 let mut mvm = space.make_new_multi_vector();
357 for (k, col) in cols.iter().enumerate() {
358 let mut cv = col_space.make_new_dense();
359 cv.set_values(col);
360 mvm.set_vector(k as Index, Rc::new(cv) as Rc<dyn Vector>);
361 }
362 Some(mvm)
363}
364
365fn dense_from_vec(v: &dyn Vector, n: usize) -> Vec<Number> {
366 if let Some(dv) = v.as_any().downcast_ref::<DenseVector>() {
367 let ev = dv.expanded_values();
368 debug_assert_eq!(ev.len(), n);
369 return ev;
370 }
371 panic!("LimMemQuasiNewtonUpdater: curvature pairs must be DenseVector-backed");
372}
373
374pub fn initial_hessian_scalar(
386 init: InitialApprox,
387 s_dot_s: Number,
388 s_dot_y: Number,
389 y_dot_y: Number,
390 min_val: Number,
391 max_val: Number,
392) -> Number {
393 let raw = match init {
394 InitialApprox::Identity => 1.0,
395 InitialApprox::Scalar1 => {
396 if s_dot_s > 0.0 {
397 s_dot_y / s_dot_s
398 } else {
399 1.0
400 }
401 }
402 InitialApprox::Scalar2 => {
403 if s_dot_y > 0.0 {
404 y_dot_y / s_dot_y
405 } else {
406 1.0
407 }
408 }
409 };
410 raw.clamp(min_val, max_val)
411}
412
413pub fn powell_damping_theta(s_dot_y: Number, s_dot_b_s: Number) -> Number {
425 if s_dot_y >= 0.2 * s_dot_b_s {
426 1.0
427 } else {
428 let denom = s_dot_b_s - s_dot_y;
429 if denom > 0.0 {
430 0.8 * s_dot_b_s / denom
431 } else {
432 1.0
433 }
434 }
435}
436
437pub fn bfgs_curvature_pair_ok(s_dot_y: Number, s_norm: Number, y_norm: Number) -> bool {
441 let eps = 1e-8_f64;
442 s_dot_y > eps * s_norm * y_norm
443}
444
445pub fn sr1_denominator_ok(yms_dot_s: Number, s_norm: Number, yms_norm: Number) -> bool {
449 let eps = 1e-8_f64;
450 yms_dot_s.abs() > eps * s_norm * yms_norm
451}
452
/// Unit tests for the scalar helpers, the pair-acceptance logic and the
/// low-rank column construction.
#[cfg(test)]
mod tests {
    use super::*;

    // --- initial_hessian_scalar -------------------------------------------

    /// `Identity` ignores the curvature data and yields 1.
    #[test]
    fn identity_init_returns_one() {
        assert_eq!(
            initial_hessian_scalar(InitialApprox::Identity, 1.0, 1.0, 1.0, 1e-8, 1e8),
            1.0
        );
    }

    /// `Scalar1` is s·y / s·s = 2 / 4.
    #[test]
    fn scalar1_init_is_sy_over_ss() {
        let v = initial_hessian_scalar(InitialApprox::Scalar1, 4.0, 2.0, 0.0, 1e-8, 1e8);
        assert!((v - 0.5).abs() < 1e-15);
    }

    /// `Scalar2` is y·y / s·y = 8 / 2.
    #[test]
    fn scalar2_init_is_yy_over_sy() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 2.0, 8.0, 1e-8, 1e8);
        assert!((v - 4.0).abs() < 1e-15);
    }

    /// A ratio of 1e20 is cut down to the upper bound.
    #[test]
    fn init_clamped_to_max() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e-20, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1e8);
    }

    /// A ratio of 1e-20 is raised to the lower bound.
    #[test]
    fn init_clamped_to_min() {
        let v = initial_hessian_scalar(InitialApprox::Scalar2, 0.0, 1e20, 1.0, 1e-8, 1e8);
        assert_eq!(v, 1e-8);
    }

    // --- powell_damping_theta ---------------------------------------------

    /// s·y >= 0.2 * sᵀBs, so no damping is applied.
    #[test]
    fn powell_no_damping_when_curvature_ok() {
        assert_eq!(powell_damping_theta(1.0, 1.0), 1.0);
    }

    /// s·y = 0.1 < 0.2, so theta = 0.8 / (1 - 0.1) = 8/9.
    #[test]
    fn powell_damps_when_curvature_violated() {
        let theta = powell_damping_theta(0.1, 1.0);
        assert!((theta - 8.0 / 9.0).abs() < 1e-15);
    }

    // --- skip criteria ----------------------------------------------------

    /// Curvature must exceed 1e-8 * ‖s‖ * ‖y‖ to be accepted.
    #[test]
    fn bfgs_skip_criterion() {
        assert!(bfgs_curvature_pair_ok(1.0, 1.0, 1.0));
        assert!(!bfgs_curvature_pair_ok(1e-10, 1.0, 1.0));
    }

    /// A large negative denominator is fine for SR1; a tiny one is not.
    #[test]
    fn sr1_skip_criterion_uses_absolute_value() {
        assert!(sr1_denominator_ok(-1.0, 1.0, 1.0));
        assert!(!sr1_denominator_ok(1e-10, 1.0, 1.0));
    }

    /// Test helper: wraps `values` in a dense vector behind `Rc<dyn Vector>`.
    fn rcv(values: &[Number]) -> Rc<dyn Vector> {
        let mut v = pounce_linalg::dense_vector::DenseVectorSpace::new(values.len() as i32)
            .make_new_dense();
        // NOTE(review): the fill presumably ensures the storage is initialised
        // before `values_mut()` is used — confirm against `DenseVector`.
        v.set(0.0);
        v.values_mut().copy_from_slice(values);
        Rc::new(v)
    }

    // --- ingest_pair --------------------------------------------------------

    /// An accepted pair is stored together with its cached scalar products.
    #[test]
    fn ingest_pair_accepts_well_curved_pair() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        let accepted = updater.ingest_pair(rcv(&[1.0, 0.0]), rcv(&[1.0, 0.0]));
        assert!(accepted);
        assert_eq!(updater.history.len(), 1);
        let pair = &updater.history[0];
        assert!((pair.s_dot_y - 1.0).abs() < 1e-15);
        assert!((pair.s_norm - 1.0).abs() < 1e-15);
        assert!((pair.y_norm - 1.0).abs() < 1e-15);
    }

    /// y = 0 gives s·y = 0, which fails the default (BFGS) skip test.
    #[test]
    fn ingest_pair_skips_zero_curvature() {
        let mut updater = LimMemQuasiNewtonUpdater::new();
        let accepted = updater.ingest_pair(rcv(&[1.0]), rcv(&[0.0]));
        assert!(!accepted);
        assert!(updater.history.is_empty());
    }

    /// Five accepted pairs with `max_history = 2` leave two entries.
    #[test]
    fn history_caps_at_max_history() {
        let mut updater = LimMemQuasiNewtonUpdater {
            max_history: 2,
            ..LimMemQuasiNewtonUpdater::default()
        };
        for _ in 0..5 {
            updater.ingest_pair(rcv(&[1.0]), rcv(&[1.0]));
        }
        assert_eq!(updater.history.len(), 2);
    }

    /// Negative curvature would be rejected by BFGS but passes the SR1 test.
    #[test]
    fn sr1_path_routes_through_sr1_skip() {
        let mut updater = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        assert!(updater.ingest_pair(rcv(&[1.0]), rcv(&[-1.0])));
    }

    /// Test helper: builds a `CurvaturePair` directly, bypassing the skip test.
    fn pair(s: &[Number], y: &[Number]) -> CurvaturePair {
        let s_rc = rcv(s);
        let y_rc = rcv(y);
        let s_dot_y = s_rc.dot(&*y_rc);
        let s_norm = s_rc.nrm2();
        let y_norm = y_rc.nrm2();
        CurvaturePair {
            s: s_rc,
            y: y_rc,
            s_dot_y,
            s_norm,
            y_norm,
        }
    }

    /// Test helper: forms the dense row-major `n x n` matrix
    /// `sigma * I + sum(v vᵀ) - sum(u uᵀ)`.
    fn reconstruct_b(n: usize, sigma: Number, v: &[Vec<Number>], u: &[Vec<Number>]) -> Vec<Number> {
        let mut b = vec![0.0_f64; n * n];
        for i in 0..n {
            b[i * n + i] = sigma;
        }
        for col in v {
            for i in 0..n {
                for j in 0..n {
                    b[i * n + j] += col[i] * col[j];
                }
            }
        }
        for col in u {
            for i in 0..n {
                for j in 0..n {
                    b[i * n + j] -= col[i] * col[j];
                }
            }
        }
        b
    }

    /// Test helper: product of a row-major `n x n` matrix with `x`.
    fn mat_vec(b: &[Number], n: usize, x: &[Number]) -> Vec<Number> {
        (0..n)
            .map(|i| (0..n).map(|j| b[i * n + j] * x[j]).sum())
            .collect()
    }

    // --- build_low_rank -----------------------------------------------------

    /// After one undamped BFGS update the secant equation `B s = y` holds.
    #[test]
    fn bfgs_low_rank_recovers_hessian_action() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 1.0], &[2.0, 5.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 1.0]);
        assert!((bs[0] - 2.0).abs() < 1e-12, "Bs[0]={}", bs[0]);
        assert!((bs[1] - 5.0).abs() < 1e-12, "Bs[1]={}", bs[1]);
    }

    /// The two off-diagonal entries of the reconstructed 2x2 matrix agree.
    #[test]
    fn bfgs_low_rank_keeps_symmetry() {
        let mut up = LimMemQuasiNewtonUpdater::new();
        up.history.push(pair(&[1.0, 0.5], &[2.0, 1.0]));
        up.history.push(pair(&[0.7, 1.2], &[1.0, 2.5]));
        let (v, u) = up.build_low_rank(3.0, 2);
        let b = reconstruct_b(2, 3.0, &v, &u);
        assert!((b[1] - b[2]).abs() < 1e-12);
    }

    /// One SR1 update with (y - Bs)·s = 5 > 0 adds a single `V` column and
    /// satisfies `B s = y`.
    #[test]
    fn sr1_low_rank_recovers_hessian_action() {
        let mut up = LimMemQuasiNewtonUpdater {
            update_type: UpdateType::Sr1,
            ..LimMemQuasiNewtonUpdater::default()
        };
        up.history.push(pair(&[1.0, 1.0], &[2.0, 5.0]));
        let (v, u) = up.build_low_rank(1.0, 2);
        assert_eq!(v.len(), 1, "positive denom routes to V");
        assert!(u.is_empty());
        let b = reconstruct_b(2, 1.0, &v, &u);
        let bs = mat_vec(&b, 2, &[1.0, 1.0]);
        assert!((bs[0] - 2.0).abs() < 1e-12);
        assert!((bs[1] - 5.0).abs() < 1e-12);
    }

    /// With no stored pairs there is nothing to add or subtract.
    #[test]
    fn empty_history_yields_no_columns() {
        let up = LimMemQuasiNewtonUpdater::new();
        let (v, u) = up.build_low_rank(1.0, 4);
        assert!(v.is_empty() && u.is_empty());
    }
}