1use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
27use pounce_common::tagged::Tag;
28use pounce_common::timing::TimingStatistics;
29use pounce_common::types::{Index, Number};
30use pounce_linalg::dense_gen_matrix::{DenseGenMatrix, DenseGenMatrixSpace};
31use pounce_linalg::dense_sym_matrix::DenseSymMatrixSpace;
32use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
33use pounce_linalg::diag_matrix::DiagMatrix;
34use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrix;
35use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
36use pounce_linalg::{Matrix, SymMatrix, Vector};
37use pounce_linsol::ESymSolverStatus;
38use std::rc::Rc;
39
40pub struct LowRankAugSystemSolver {
41 inner: Box<dyn AugSystemSolver>,
43 first_call: bool,
45 num_neg_evals: Index,
47 cache: AugSysCache,
49 factor: Factorization,
51}
52
53#[derive(Debug, Clone)]
54pub struct AugSysCache {
55 pub w_tag: Tag,
56 pub w_factor: Number,
57 pub d_x_tag: Tag,
58 pub delta_x: Number,
59 pub d_s_tag: Tag,
60 pub delta_s: Number,
61 pub j_c_tag: Tag,
62 pub d_c_tag: Tag,
63 pub delta_c: Number,
64 pub j_d_tag: Tag,
65 pub d_d_tag: Tag,
66 pub delta_d: Number,
67}
68
69impl Default for AugSysCache {
70 fn default() -> Self {
71 Self {
72 w_tag: Tag::NONE,
73 w_factor: 0.0,
74 d_x_tag: Tag::NONE,
75 delta_x: 0.0,
76 d_s_tag: Tag::NONE,
77 delta_s: 0.0,
78 j_c_tag: Tag::NONE,
79 d_c_tag: Tag::NONE,
80 delta_c: 0.0,
81 j_d_tag: Tag::NONE,
82 d_d_tag: Tag::NONE,
83 delta_d: 0.0,
84 }
85 }
86}
87
88#[derive(Default)]
89struct Factorization {
90 wdiag: Option<Box<DiagMatrix>>,
94 j1: Option<DenseGenMatrix>,
96 j2: Option<DenseGenMatrix>,
98 vtilde1_x: Option<MultiVectorMatrix>,
100 vtilde1_s: Option<MultiVectorMatrix>,
101 vtilde1_c: Option<MultiVectorMatrix>,
102 vtilde1_d: Option<MultiVectorMatrix>,
103 utilde2_x: Option<MultiVectorMatrix>,
105 utilde2_s: Option<MultiVectorMatrix>,
106 utilde2_c: Option<MultiVectorMatrix>,
107 utilde2_d: Option<MultiVectorMatrix>,
108}
109
110impl LowRankAugSystemSolver {
111 pub fn new(inner: Box<dyn AugSystemSolver>) -> Self {
112 Self {
113 inner,
114 first_call: true,
115 num_neg_evals: 0,
116 cache: AugSysCache::default(),
117 factor: Factorization::default(),
118 }
119 }
120
121 pub fn augmented_system_requires_change(&self, coeffs: &AugSysCoeffs<'_>) -> bool {
124 let cache = &self.cache;
125 let zero_tag: Tag = Tag::NONE;
126
127 let w_changed = match coeffs.w {
128 Some(w) => w.as_tagged().get_tag() != cache.w_tag,
129 None => cache.w_tag != zero_tag,
130 };
131 if w_changed || coeffs.w_factor != cache.w_factor {
132 return true;
133 }
134 let dx_changed = match coeffs.d_x {
135 Some(d) => d.as_tagged().get_tag() != cache.d_x_tag,
136 None => cache.d_x_tag != zero_tag,
137 };
138 if dx_changed || coeffs.delta_x != cache.delta_x {
139 return true;
140 }
141 let ds_changed = match coeffs.d_s {
142 Some(d) => d.as_tagged().get_tag() != cache.d_s_tag,
143 None => cache.d_s_tag != zero_tag,
144 };
145 if ds_changed || coeffs.delta_s != cache.delta_s {
146 return true;
147 }
148 if coeffs.j_c.as_tagged().get_tag() != cache.j_c_tag {
149 return true;
150 }
151 let dc_changed = match coeffs.d_c {
152 Some(d) => d.as_tagged().get_tag() != cache.d_c_tag,
153 None => cache.d_c_tag != zero_tag,
154 };
155 if dc_changed || coeffs.delta_c != cache.delta_c {
156 return true;
157 }
158 if coeffs.j_d.as_tagged().get_tag() != cache.j_d_tag {
159 return true;
160 }
161 let dd_changed = match coeffs.d_d {
162 Some(d) => d.as_tagged().get_tag() != cache.d_d_tag,
163 None => cache.d_d_tag != zero_tag,
164 };
165 if dd_changed || coeffs.delta_d != cache.delta_d {
166 return true;
167 }
168 false
169 }
170
171 fn store_cache(&mut self, coeffs: &AugSysCoeffs<'_>) {
172 let zero_tag = Tag::NONE;
173 self.cache.w_tag = coeffs
174 .w
175 .map(|w| w.as_tagged().get_tag())
176 .unwrap_or(zero_tag);
177 self.cache.w_factor = coeffs.w_factor;
178 self.cache.d_x_tag = coeffs
179 .d_x
180 .map(|d| d.as_tagged().get_tag())
181 .unwrap_or(zero_tag);
182 self.cache.delta_x = coeffs.delta_x;
183 self.cache.d_s_tag = coeffs
184 .d_s
185 .map(|d| d.as_tagged().get_tag())
186 .unwrap_or(zero_tag);
187 self.cache.delta_s = coeffs.delta_s;
188 self.cache.j_c_tag = coeffs.j_c.as_tagged().get_tag();
189 self.cache.d_c_tag = coeffs
190 .d_c
191 .map(|d| d.as_tagged().get_tag())
192 .unwrap_or(zero_tag);
193 self.cache.delta_c = coeffs.delta_c;
194 self.cache.j_d_tag = coeffs.j_d.as_tagged().get_tag();
195 self.cache.d_d_tag = coeffs
196 .d_d
197 .map(|d| d.as_tagged().get_tag())
198 .unwrap_or(zero_tag);
199 self.cache.delta_d = coeffs.delta_d;
200 }
201
202 pub fn first_call(&self) -> bool {
203 self.first_call
204 }
205
206 pub fn cache(&self) -> &AugSysCache {
207 &self.cache
208 }
209
210 fn update_factorization(
217 &mut self,
218 lr_w: &LowRankUpdateSymMatrix,
219 coeffs: &AugSysCoeffs<'_>,
220 proto: &AugSysRhs<'_>,
221 check_neg_evals: bool,
222 num_neg_evals: Index,
223 ) -> ESymSolverStatus {
224 let proto_x = downcast_dense(proto.rhs_x);
225 let proto_s = downcast_dense(proto.rhs_s);
226 let proto_c = downcast_dense(proto.rhs_c);
227 let proto_d = downcast_dense(proto.rhs_d);
228 let space_x = Rc::clone(proto_x.space());
229 let space_s = Rc::clone(proto_s.space());
230 let space_c = Rc::clone(proto_c.space());
231 let space_d = Rc::clone(proto_d.space());
232
233 let b0_dense: DenseVector = if coeffs.w_factor == 1.0 {
237 match lr_w.get_diag() {
238 Some(d) => clone_dense(downcast_dense(d.as_ref())),
239 None => zero_x_for(&space_x, lr_w),
240 }
241 } else {
242 zero_x_for(&space_x, lr_w)
243 };
244
245 let wdiag_diag: Rc<dyn Vector> = match (lr_w.p_lowrank(), lr_w.reduced_diag()) {
246 (Some(p_lm), true) => {
247 let mut fullx = space_x.make_new_dense();
249 p_lm.mult_vector(1.0, &b0_dense, 0.0, &mut fullx);
250 Rc::new(fullx) as Rc<dyn Vector>
251 }
252 _ => Rc::new(clone_dense(&b0_dense)) as Rc<dyn Vector>,
253 };
254 let mut wdiag = Box::new(DiagMatrix::new(space_x.dim()));
255 wdiag.set_diag(wdiag_diag);
256 self.factor.wdiag = Some(wdiag);
257
258 if coeffs.w_factor == 1.0 && lr_w.get_v().is_some() {
260 let v = Rc::clone(lr_w.get_v().unwrap());
261 let n_v = v.n_cols();
262
263 let v_x_space = MultiVectorMatrixSpace::new(n_v, Rc::clone(&space_x));
267 let mut v_x = v_x_space.make_new_multi_vector();
268 for k in 0..n_v {
269 let vk = Rc::clone(v.get_vector(k));
270 let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
271 Some(p_lm) => {
272 let mut fullx = space_x.make_new_dense();
273 p_lm.mult_vector(1.0, vk.as_ref(), 0.0, &mut fullx);
274 Rc::new(fullx) as Rc<dyn Vector>
275 }
276 None => vk,
277 };
278 v_x.set_vector(k, rhs_x_k);
279 }
280
281 let (vt_x, vt_s, vt_c, vt_d) = self.multi_solve_block(
282 &v_x,
283 coeffs,
284 &space_x,
285 &space_s,
286 &space_c,
287 &space_d,
288 check_neg_evals,
289 num_neg_evals,
290 );
291
292 let vt_x = match vt_x {
293 Ok(x) => x,
294 Err(status) => return status,
295 };
296
297 let m1_space = DenseSymMatrixSpace::new(n_v);
299 let mut m1 = m1_space.make_new_dense_sym();
300 m1.fill_identity(1.0);
301 m1.high_rank_update_transpose(1.0, &vt_x, &v_x, 1.0);
302 let j1_space = DenseGenMatrixSpace::new(n_v, n_v);
303 let mut j1 = j1_space.make_new_dense_gen();
304 if !j1.compute_cholesky_factor(&m1) {
305 self.num_neg_evals += 1;
306 return ESymSolverStatus::WrongInertia;
307 }
308 self.factor.vtilde1_x = Some(vt_x);
309 self.factor.vtilde1_s = Some(vt_s);
310 self.factor.vtilde1_c = Some(vt_c);
311 self.factor.vtilde1_d = Some(vt_d);
312 self.factor.j1 = Some(j1);
313 } else {
314 self.factor.vtilde1_x = None;
315 self.factor.vtilde1_s = None;
316 self.factor.vtilde1_c = None;
317 self.factor.vtilde1_d = None;
318 self.factor.j1 = None;
319 }
320
321 if coeffs.w_factor == 1.0 && lr_w.get_u().is_some() {
324 let u = Rc::clone(lr_w.get_u().unwrap());
325 let n_u = u.n_cols();
326
327 let u_x_space = MultiVectorMatrixSpace::new(n_u, Rc::clone(&space_x));
328 let mut u_x = u_x_space.make_new_multi_vector();
329 for k in 0..n_u {
330 let uk = Rc::clone(u.get_vector(k));
331 let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
332 Some(p_lm) => {
333 let mut fullx = space_x.make_new_dense();
334 p_lm.mult_vector(1.0, uk.as_ref(), 0.0, &mut fullx);
335 Rc::new(fullx) as Rc<dyn Vector>
336 }
337 None => uk,
338 };
339 u_x.set_vector(k, rhs_x_k);
340 }
341
342 let (mut ut_x, mut ut_s, mut ut_c, mut ut_d) = match self.multi_solve_block(
343 &u_x,
344 coeffs,
345 &space_x,
346 &space_s,
347 &space_c,
348 &space_d,
349 check_neg_evals,
350 num_neg_evals,
351 ) {
352 (Ok(x), s, c, d) => (x, s, c, d),
353 (Err(status), _, _, _) => return status,
354 };
355
356 if self.factor.vtilde1_x.is_some() {
358 let vt1_x = self.factor.vtilde1_x.as_ref().unwrap();
359 let vt1_s = self.factor.vtilde1_s.as_ref().unwrap();
360 let vt1_c = self.factor.vtilde1_c.as_ref().unwrap();
361 let vt1_d = self.factor.vtilde1_d.as_ref().unwrap();
362 let n_v = vt1_x.n_cols();
363 let c_space = DenseGenMatrixSpace::new(n_v, n_u);
367 let mut c_mat = c_space.make_new_dense_gen();
368 {
369 let cv = c_mat.values_mut();
370 for j in 0..n_u as usize {
371 let uj = u_x.get_vector(j as Index).as_ref();
372 for i in 0..n_v as usize {
373 let vi = vt1_x.get_vector(i as Index).as_ref();
374 cv[i + j * n_v as usize] = vi.dot(uj);
375 }
376 }
377 }
378 self.factor
379 .j1
380 .as_ref()
381 .unwrap()
382 .cholesky_solve_matrix(&mut c_mat);
383 ut_x.add_right_mult_matrix(-1.0, vt1_x, &c_mat, 1.0);
384 ut_s.add_right_mult_matrix(-1.0, vt1_s, &c_mat, 1.0);
385 ut_c.add_right_mult_matrix(-1.0, vt1_c, &c_mat, 1.0);
386 ut_d.add_right_mult_matrix(-1.0, vt1_d, &c_mat, 1.0);
387 }
388
389 let m2_space = DenseSymMatrixSpace::new(n_u);
394 let mut m2 = m2_space.make_new_dense_sym();
395 m2.fill_identity(1.0);
396 m2.high_rank_update_transpose(-1.0, &ut_x, &u_x, 1.0);
397 let j2_space = DenseGenMatrixSpace::new(n_u, n_u);
398 let mut j2 = j2_space.make_new_dense_gen();
399 if !j2.compute_cholesky_factor(&m2) {
400 self.num_neg_evals += 1;
401 return ESymSolverStatus::WrongInertia;
402 }
403 self.factor.utilde2_x = Some(ut_x);
404 self.factor.utilde2_s = Some(ut_s);
405 self.factor.utilde2_c = Some(ut_c);
406 self.factor.utilde2_d = Some(ut_d);
407 self.factor.j2 = Some(j2);
408 } else {
409 self.factor.utilde2_x = None;
410 self.factor.utilde2_s = None;
411 self.factor.utilde2_c = None;
412 self.factor.utilde2_d = None;
413 self.factor.j2 = None;
414 }
415
416 ESymSolverStatus::Success
417 }
418
419 #[allow(clippy::too_many_arguments)]
424 fn multi_solve_block(
425 &mut self,
426 v_x: &MultiVectorMatrix,
427 coeffs: &AugSysCoeffs<'_>,
428 space_x: &Rc<DenseVectorSpace>,
429 space_s: &Rc<DenseVectorSpace>,
430 space_c: &Rc<DenseVectorSpace>,
431 space_d: &Rc<DenseVectorSpace>,
432 check_neg_evals: bool,
433 num_neg_evals: Index,
434 ) -> (
435 Result<MultiVectorMatrix, ESymSolverStatus>,
436 MultiVectorMatrix,
437 MultiVectorMatrix,
438 MultiVectorMatrix,
439 ) {
440 let n_cols = v_x.n_cols();
441
442 let mut out_x =
444 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_x)).make_new_multi_vector();
445 let mut out_s =
446 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_s)).make_new_multi_vector();
447 let mut out_c =
448 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_c)).make_new_multi_vector();
449 let mut out_d =
450 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_d)).make_new_multi_vector();
451 out_x.fill_with_new_vectors();
452 out_s.fill_with_new_vectors();
453 out_c.fill_with_new_vectors();
454 out_d.fill_with_new_vectors();
455
456 let mut rhs_s = space_s.make_new_dense();
459 rhs_s.set(0.0);
460 let mut rhs_c = space_c.make_new_dense();
461 rhs_c.set(0.0);
462 let mut rhs_d = space_d.make_new_dense();
463 rhs_d.set(0.0);
464
465 for k in 0..n_cols {
466 let rhs_x_dyn: &dyn Vector = v_x.get_vector(k).as_ref();
467 let inner_rhs = AugSysRhs {
468 rhs_x: rhs_x_dyn,
469 rhs_s: rhs_s.as_dyn_vector(),
470 rhs_c: rhs_c.as_dyn_vector(),
471 rhs_d: rhs_d.as_dyn_vector(),
472 };
473 let mut sol_x = space_x.make_new_dense();
475 let mut sol_s = space_s.make_new_dense();
476 let mut sol_c = space_c.make_new_dense();
477 let mut sol_d = space_d.make_new_dense();
478 sol_x.set(0.0);
479 sol_s.set(0.0);
480 sol_c.set(0.0);
481 sol_d.set(0.0);
482 let inner_coeffs = inner_coeffs(&self.factor, coeffs);
483 let status = {
484 let mut sol = AugSysSol {
485 sol_x: &mut sol_x,
486 sol_s: &mut sol_s,
487 sol_c: &mut sol_c,
488 sol_d: &mut sol_d,
489 };
490 self.inner.solve(
491 &inner_coeffs,
492 &inner_rhs,
493 &mut sol,
494 check_neg_evals,
495 num_neg_evals,
496 )
497 };
498 if self.inner.provides_inertia() {
499 self.num_neg_evals = self.inner.number_of_neg_evals();
500 }
501 if status != ESymSolverStatus::Success {
502 return (Err(status), out_s, out_c, out_d);
503 }
504 out_x.set_vector(k, Rc::new(sol_x) as Rc<dyn Vector>);
505 out_s.set_vector(k, Rc::new(sol_s) as Rc<dyn Vector>);
506 out_c.set_vector(k, Rc::new(sol_c) as Rc<dyn Vector>);
507 out_d.set_vector(k, Rc::new(sol_d) as Rc<dyn Vector>);
508 }
509 (Ok(out_x), out_s, out_c, out_d)
510 }
511}
512
513fn inner_coeffs<'b>(factor: &'b Factorization, coeffs: &AugSysCoeffs<'b>) -> AugSysCoeffs<'b> {
518 let wdiag: &DiagMatrix = factor.wdiag.as_ref().expect("Wdiag unset").as_ref();
519 AugSysCoeffs {
520 w: Some(wdiag as &dyn SymMatrix),
521 w_factor: 1.0,
522 d_x: coeffs.d_x,
523 delta_x: coeffs.delta_x,
524 d_s: coeffs.d_s,
525 delta_s: coeffs.delta_s,
526 j_c: coeffs.j_c,
527 d_c: coeffs.d_c,
528 delta_c: coeffs.delta_c,
529 j_d: coeffs.j_d,
530 d_d: coeffs.d_d,
531 delta_d: coeffs.delta_d,
532 }
533}
534
535fn downcast_dense(v: &dyn Vector) -> &DenseVector {
536 v.as_any()
537 .downcast_ref::<DenseVector>()
538 .expect("LowRankAugSystemSolver currently requires DenseVector RHS/solutions")
539}
540
541fn clone_dense(src: &DenseVector) -> DenseVector {
545 let mut out = src.space().make_new_dense();
546 out.set_values(&src.expanded_values());
547 out
548}
549
550fn zero_x_for(space_x: &Rc<DenseVectorSpace>, lr_w: &LowRankUpdateSymMatrix) -> DenseVector {
551 let _ = lr_w;
556 let mut z = space_x.make_new_dense();
557 z.set(0.0);
558 z
559}
560
561impl AugSystemSolver for LowRankAugSystemSolver {
562 fn provides_inertia(&self) -> bool {
563 self.inner.provides_inertia()
564 }
565
566 fn number_of_neg_evals(&self) -> Index {
567 if self.inner.provides_inertia() {
568 self.inner.number_of_neg_evals()
569 } else {
570 self.num_neg_evals
571 }
572 }
573
574 fn increase_quality(&mut self) -> bool {
575 self.inner.increase_quality()
576 }
577
578 fn last_solve_status(&self) -> ESymSolverStatus {
579 self.inner.last_solve_status()
580 }
581
582 fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
583 self.inner.set_timing_stats(timing);
584 }
585
586 fn solve(
587 &mut self,
588 coeffs: &AugSysCoeffs<'_>,
589 rhs: &AugSysRhs<'_>,
590 sol: &mut AugSysSol<'_>,
591 check_neg_evals: bool,
592 num_neg_evals: Index,
593 ) -> ESymSolverStatus {
594 let mut check_neg_evals = check_neg_evals;
597 if !self.inner.provides_inertia() {
598 check_neg_evals = false;
599 }
600
601 let lr_w_opt = coeffs
610 .w
611 .and_then(|w| w.as_any().downcast_ref::<LowRankUpdateSymMatrix>());
612 let Some(lr_w) = lr_w_opt else {
613 let status = self
614 .inner
615 .solve(coeffs, rhs, sol, check_neg_evals, num_neg_evals);
616 if self.inner.provides_inertia() {
617 self.num_neg_evals = self.inner.number_of_neg_evals();
618 }
619 return status;
620 };
621
622 let needs_rebuild = self.first_call || self.augmented_system_requires_change(coeffs);
623 if needs_rebuild {
624 let status =
625 self.update_factorization(lr_w, coeffs, rhs, check_neg_evals, num_neg_evals);
626 if status != ESymSolverStatus::Success {
627 return status;
628 }
629 self.store_cache(coeffs);
630 self.first_call = false;
631 }
632
633 let ic = inner_coeffs(&self.factor, coeffs);
635 let status = self
636 .inner
637 .solve(&ic, rhs, sol, check_neg_evals, num_neg_evals);
638 if self.inner.provides_inertia() {
639 self.num_neg_evals = self.inner.number_of_neg_evals();
640 }
641 if status != ESymSolverStatus::Success {
642 return status;
643 }
644
645 if self.factor.utilde2_x.is_some() {
648 self.apply_smw(1.0, true, rhs, sol);
649 }
650 if self.factor.vtilde1_x.is_some() {
651 self.apply_smw(-1.0, false, rhs, sol);
652 }
653
654 ESymSolverStatus::Success
655 }
656}
657
658impl LowRankAugSystemSolver {
659 fn apply_smw(&self, sign: Number, use_u: bool, rhs: &AugSysRhs<'_>, sol: &mut AugSysSol<'_>) {
665 let (mvx, mvs, mvc, mvd, j) = if use_u {
666 (
667 self.factor.utilde2_x.as_ref().unwrap(),
668 self.factor.utilde2_s.as_ref().unwrap(),
669 self.factor.utilde2_c.as_ref().unwrap(),
670 self.factor.utilde2_d.as_ref().unwrap(),
671 self.factor.j2.as_ref().unwrap(),
672 )
673 } else {
674 (
675 self.factor.vtilde1_x.as_ref().unwrap(),
676 self.factor.vtilde1_s.as_ref().unwrap(),
677 self.factor.vtilde1_c.as_ref().unwrap(),
678 self.factor.vtilde1_d.as_ref().unwrap(),
679 self.factor.j1.as_ref().unwrap(),
680 )
681 };
682 let n = mvx.n_cols();
683 let mut b_vec: Vec<Number> = Vec::with_capacity(n as usize);
687 for k in 0..n {
688 let dot = mvx.get_vector(k).dot(rhs.rhs_x)
689 + mvs.get_vector(k).dot(rhs.rhs_s)
690 + mvc.get_vector(k).dot(rhs.rhs_c)
691 + mvd.get_vector(k).dot(rhs.rhs_d);
692 b_vec.push(dot);
693 }
694 let space_b = DenseVectorSpace::new(n);
695 let mut b = space_b.make_new_dense();
696 b.set_values(&b_vec);
697 j.cholesky_solve_vector(&mut b);
699 mvx.mult_vector(sign, &b, 1.0, sol.sol_x);
701 mvs.mult_vector(sign, &b, 1.0, sol.sol_s);
702 mvc.mult_vector(sign, &b, 1.0, sol.sol_c);
703 mvd.mult_vector(sign, &b, 1.0, sol.sol_d);
704 }
705}
706
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_linalg::dense_vector::DenseVectorSpace;
    use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
    use std::cell::Cell;

    /// Inner solver for systems with empty `s`/`c`/`d` blocks.
    ///
    /// It solves `(diag(W) + D_x + delta_x) x = rhs_x` element-wise and
    /// counts how often it is called.
    struct DiagInner {
        /// Shared with the test so it can observe the number of inner solves.
        calls: Rc<Cell<usize>>,
    }
    impl AugSystemSolver for DiagInner {
        fn provides_inertia(&self) -> bool {
            true
        }
        fn number_of_neg_evals(&self) -> Index {
            0
        }
        fn increase_quality(&mut self) -> bool {
            true
        }
        fn last_solve_status(&self) -> ESymSolverStatus {
            ESymSolverStatus::Success
        }
        fn solve(
            &mut self,
            coeffs: &AugSysCoeffs<'_>,
            rhs: &AugSysRhs<'_>,
            sol: &mut AugSysSol<'_>,
            _check_neg_evals: bool,
            _num_neg_evals: Index,
        ) -> ESymSolverStatus {
            self.calls.set(self.calls.get() + 1);
            let wdiag = coeffs
                .w
                .expect("DiagInner requires W")
                .as_any()
                .downcast_ref::<DiagMatrix>()
                .expect("DiagInner requires W to be a DiagMatrix");
            let diag_rc = wdiag.get_diag().expect("Wdiag has no diag set").clone();
            let diag = downcast_dense(diag_rc.as_ref()).expanded_values();
            let rhs_x = downcast_dense(rhs.rhs_x).expanded_values();
            let dx_vals: Option<Vec<Number>> =
                coeffs.d_x.map(|d| downcast_dense(d).expanded_values());
            let mut out = vec![0.0; rhs_x.len()];
            for i in 0..rhs_x.len() {
                let dx_i = match &dx_vals {
                    Some(v) => v[i],
                    None => 0.0,
                };
                let denom = diag[i] + dx_i + coeffs.delta_x;
                out[i] = rhs_x[i] / denom;
            }
            let sol_x_dv = sol
                .sol_x
                .as_any_mut()
                .downcast_mut::<DenseVector>()
                .unwrap();
            sol_x_dv.set_values(&out);
            ESymSolverStatus::Success
        }
    }

    fn dvec(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> DenseVector {
        let mut v = space.make_new_dense();
        v.set_values(vals);
        v
    }

    fn dvec_rc(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> Rc<DenseVector> {
        Rc::new(dvec(space, vals))
    }

    /// Wraps a fresh `DiagInner` and returns the solver plus the inner call counter.
    fn counting_solver() -> (LowRankAugSystemSolver, Rc<Cell<usize>>) {
        let calls = Rc::new(Cell::new(0));
        let inner = DiagInner {
            calls: Rc::clone(&calls),
        };
        (LowRankAugSystemSolver::new(Box::new(inner)), calls)
    }

    /// Single-column multi-vector over `space_x` holding `value`.
    fn column(space_x: &Rc<DenseVectorSpace>, value: Number) -> Rc<MultiVectorMatrix> {
        let mv_space = MultiVectorMatrixSpace::new(1, Rc::clone(space_x));
        let mut mv = mv_space.make_new_multi_vector();
        mv.set_vector(0, dvec_rc(space_x, &[value]) as Rc<dyn Vector>);
        Rc::new(mv)
    }

    /// 1×1 low-rank matrix `W = b0 + v² − u²`. Absent `v`/`u` leave the term out.
    fn low_rank_1d(
        space_x: &Rc<DenseVectorSpace>,
        b0: Number,
        v: Option<Number>,
        u: Option<Number>,
    ) -> LowRankUpdateSymMatrix {
        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(dvec_rc(space_x, &[b0]));
        if let Some(v) = v {
            lr.set_v(column(space_x, v));
        }
        if let Some(u) = u {
            lr.set_u(column(space_x, u));
        }
        lr
    }

    /// A 0×1 Jacobian, i.e. no constraints of that kind.
    fn empty_jacobian() -> DenseGenMatrix {
        let space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        space.make_new_dense_gen()
    }

    /// Coefficients with `W = w`, shift `delta_x`, and nothing else.
    fn coeffs_1d<'a>(
        w: &'a LowRankUpdateSymMatrix,
        j_c: &'a DenseGenMatrix,
        j_d: &'a DenseGenMatrix,
        delta_x: Number,
    ) -> AugSysCoeffs<'a> {
        AugSysCoeffs {
            w: Some(w as &dyn SymMatrix),
            w_factor: 1.0,
            d_x: None,
            delta_x,
            d_s: None,
            delta_s: 0.0,
            j_c: j_c as &dyn Matrix,
            d_c: None,
            delta_c: 0.0,
            j_d: j_d as &dyn Matrix,
            d_d: None,
            delta_d: 0.0,
        }
    }

    /// Runs one `solve` with `rhs_x = [rhs]` and empty `s`/`c`/`d` blocks.
    /// Returns the status and the computed `x`.
    fn solve_1d(
        solver: &mut LowRankAugSystemSolver,
        coeffs: &AugSysCoeffs<'_>,
        space_x: &Rc<DenseVectorSpace>,
        rhs: Number,
    ) -> (ESymSolverStatus, Number) {
        let space_zero = DenseVectorSpace::new(0);
        let rhs_x = dvec(space_x, &[rhs]);
        let rhs_zero = dvec(&space_zero, &[]);
        let aug_rhs = AugSysRhs {
            rhs_x: &rhs_x,
            rhs_s: &rhs_zero,
            rhs_c: &rhs_zero,
            rhs_d: &rhs_zero,
        };
        let mut sol_x = dvec(space_x, &[0.0]);
        let mut sol_s = dvec(&space_zero, &[]);
        let mut sol_c = dvec(&space_zero, &[]);
        let mut sol_d = dvec(&space_zero, &[]);
        let status = {
            let mut sol = AugSysSol {
                sol_x: &mut sol_x,
                sol_s: &mut sol_s,
                sol_c: &mut sol_c,
                sol_d: &mut sol_d,
            };
            solver.solve(coeffs, &aug_rhs, &mut sol, false, 0)
        };
        (status, sol_x.expanded_values()[0])
    }

    #[test]
    fn smw_recovers_low_rank_inverse() {
        // W = 2 + 3² = 11.
        let space_x = DenseVectorSpace::new(1);
        let w = low_rank_1d(&space_x, 2.0, Some(3.0), None);
        let (j_c, j_d) = (empty_jacobian(), empty_jacobian());
        let coeffs = coeffs_1d(&w, &j_c, &j_d, 0.0);
        let (mut solver, _) = counting_solver();
        let (status, got) = solve_1d(&mut solver, &coeffs, &space_x, 5.0);
        assert_eq!(status, ESymSolverStatus::Success);
        let want = 5.0 / 11.0;
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    #[test]
    fn smw_with_u_only_applies_positive_correction() {
        // W = 5 − 1.5² = 2.75.
        let space_x = DenseVectorSpace::new(1);
        let w = low_rank_1d(&space_x, 5.0, None, Some(1.5));
        let (j_c, j_d) = (empty_jacobian(), empty_jacobian());
        let coeffs = coeffs_1d(&w, &j_c, &j_d, 0.0);
        let (mut solver, _) = counting_solver();
        let (status, got) = solve_1d(&mut solver, &coeffs, &space_x, 7.0);
        assert_eq!(status, ESymSolverStatus::Success);
        let want = 7.0 / 2.75;
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    #[test]
    fn smw_reports_wrong_inertia_on_indefinite_negative_update() {
        // W = 2 − 2² = −2, so M2 = 1 − 2·(2/2) < 0 cannot be Cholesky-factored.
        let space_x = DenseVectorSpace::new(1);
        let w = low_rank_1d(&space_x, 2.0, None, Some(2.0));
        let (j_c, j_d) = (empty_jacobian(), empty_jacobian());
        let coeffs = coeffs_1d(&w, &j_c, &j_d, 0.0);
        let (mut solver, _) = counting_solver();
        let (status, _) = solve_1d(&mut solver, &coeffs, &space_x, 1.0);
        assert_eq!(status, ESymSolverStatus::WrongInertia);
    }

    #[test]
    fn smw_with_v_and_u_combines_corrections() {
        // W = 10 + 2² − 1² = 13.
        let space_x = DenseVectorSpace::new(1);
        let w = low_rank_1d(&space_x, 10.0, Some(2.0), Some(1.0));
        let (j_c, j_d) = (empty_jacobian(), empty_jacobian());
        let coeffs = coeffs_1d(&w, &j_c, &j_d, 0.0);
        let (mut solver, _) = counting_solver();
        let (status, got) = solve_1d(&mut solver, &coeffs, &space_x, 1.0);
        assert_eq!(status, ESymSolverStatus::Success);
        let want = 1.0 / 13.0;
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    #[test]
    fn unchanged_coeffs_skip_rebuild_after_first_call() {
        // The V term makes a rebuild cost one extra inner solve, so the call
        // counter distinguishes a rebuild from a reuse. W + delta_x = 2 + 1 + 0.001.
        let space_x = DenseVectorSpace::new(1);
        let w = low_rank_1d(&space_x, 2.0, Some(1.0), None);
        let (j_c, j_d) = (empty_jacobian(), empty_jacobian());
        let coeffs = coeffs_1d(&w, &j_c, &j_d, 0.001);
        let (mut lr_solver, calls) = counting_solver();
        let want = 1.0 / 3.001;

        let (status, got) = solve_1d(&mut lr_solver, &coeffs, &space_x, 1.0);
        assert_eq!(status, ESymSolverStatus::Success);
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
        // First call: one solve for the V column plus the main solve.
        assert_eq!(calls.get(), 2);
        assert!(!lr_solver.first_call());
        assert!(!lr_solver.augmented_system_requires_change(&coeffs));

        let (status, got) = solve_1d(&mut lr_solver, &coeffs, &space_x, 1.0);
        assert_eq!(status, ESymSolverStatus::Success);
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
        // Second call reuses the cached factorization: only the main solve runs.
        assert_eq!(calls.get(), 3);
    }
}