1use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
27use pounce_common::tagged::Tag;
28use pounce_common::timing::TimingStatistics;
29use pounce_common::types::{Index, Number};
30use pounce_linalg::dense_gen_matrix::{DenseGenMatrix, DenseGenMatrixSpace};
31use pounce_linalg::dense_sym_matrix::DenseSymMatrixSpace;
32use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
33use pounce_linalg::diag_matrix::DiagMatrix;
34use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrix;
35use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
36use pounce_linalg::{Matrix, SymMatrix, Vector};
37use pounce_linsol::ESymSolverStatus;
38use std::rc::Rc;
39
/// Augmented-system solver for a `W` block given as a `LowRankUpdateSymMatrix`.
///
/// The wrapped `inner` solver is only ever handed the diagonal part of `W`
/// (as a `DiagMatrix`); the low-rank terms are added back through
/// Sherman-Morrison-Woodbury corrections that are built when the
/// coefficients change and applied after every inner solve.
pub struct LowRankAugSystemSolver {
    /// Solver used for the augmented system with the diagonal `W`.
    inner: Box<dyn AugSystemSolver>,
    /// When `true`, `solve` rebuilds the factorization unconditionally.
    first_call: bool,
    /// Negative-eigenvalue count: copied from `inner` after its solves when it
    /// provides inertia, and incremented when a capacitance-matrix Cholesky
    /// factorization fails.
    num_neg_evals: Index,
    /// Tags and scalars of the coefficients the current factorization was
    /// built from; used to decide whether a rebuild is needed.
    cache: AugSysCache,
    /// Current factorization data (diagonal `W` and the correction terms).
    factor: Factorization,
}
52
/// Snapshot of the coefficient identities a factorization was built from.
///
/// Matrices and vectors are recorded by their tag (`Tag::NONE` when the
/// optional block was absent); scalar factors are recorded by value.
#[derive(Debug, Clone)]
pub struct AugSysCache {
    /// Tag of `W` (or `Tag::NONE` if absent).
    pub w_tag: Tag,
    /// Scalar multiplier applied to `W`.
    pub w_factor: Number,
    /// Tag of `D_x` (or `Tag::NONE` if absent).
    pub d_x_tag: Tag,
    /// Scalar shift `delta_x`.
    pub delta_x: Number,
    /// Tag of `D_s` (or `Tag::NONE` if absent).
    pub d_s_tag: Tag,
    /// Scalar shift `delta_s`.
    pub delta_s: Number,
    /// Tag of the Jacobian `J_c`.
    pub j_c_tag: Tag,
    /// Tag of `D_c` (or `Tag::NONE` if absent).
    pub d_c_tag: Tag,
    /// Scalar shift `delta_c`.
    pub delta_c: Number,
    /// Tag of the Jacobian `J_d`.
    pub j_d_tag: Tag,
    /// Tag of `D_d` (or `Tag::NONE` if absent).
    pub d_d_tag: Tag,
    /// Scalar shift `delta_d`.
    pub delta_d: Number,
}
68
69impl Default for AugSysCache {
70 fn default() -> Self {
71 Self {
72 w_tag: Tag::NONE,
73 w_factor: 0.0,
74 d_x_tag: Tag::NONE,
75 delta_x: 0.0,
76 d_s_tag: Tag::NONE,
77 delta_s: 0.0,
78 j_c_tag: Tag::NONE,
79 d_c_tag: Tag::NONE,
80 delta_c: 0.0,
81 j_d_tag: Tag::NONE,
82 d_d_tag: Tag::NONE,
83 delta_d: 0.0,
84 }
85 }
86}
87
/// Data produced by a factorization rebuild and consumed by every solve.
///
/// The `*_x`/`*_s`/`*_c`/`*_d` fields hold the x, s, c and d components of
/// the same multi-vector and are set (or cleared) together.
#[derive(Default)]
struct Factorization {
    /// Diagonal `W` handed to the inner solver.
    wdiag: Option<Box<DiagMatrix>>,
    /// Cholesky factor of `M1 = I + Vtilde1_x^T V_x` (present iff a V term exists).
    j1: Option<DenseGenMatrix>,
    /// Cholesky factor of `M2 = I - Utilde2_x^T U_x` (present iff a U term exists).
    j2: Option<DenseGenMatrix>,
    /// Inner-solver solutions with the V columns as x right-hand sides.
    vtilde1_x: Option<MultiVectorMatrix>,
    vtilde1_s: Option<MultiVectorMatrix>,
    vtilde1_c: Option<MultiVectorMatrix>,
    vtilde1_d: Option<MultiVectorMatrix>,
    /// Inner-solver solutions for the U columns, with the Vtilde1 component
    /// removed.
    utilde2_x: Option<MultiVectorMatrix>,
    utilde2_s: Option<MultiVectorMatrix>,
    utilde2_c: Option<MultiVectorMatrix>,
    utilde2_d: Option<MultiVectorMatrix>,
}
109
110impl LowRankAugSystemSolver {
111 pub fn new(inner: Box<dyn AugSystemSolver>) -> Self {
112 Self {
113 inner,
114 first_call: true,
115 num_neg_evals: 0,
116 cache: AugSysCache::default(),
117 factor: Factorization::default(),
118 }
119 }
120
121 pub fn augmented_system_requires_change(&self, coeffs: &AugSysCoeffs<'_>) -> bool {
124 let cache = &self.cache;
125 let zero_tag: Tag = Tag::NONE;
126
127 let w_changed = match coeffs.w {
128 Some(w) => w.as_tagged().get_tag() != cache.w_tag,
129 None => cache.w_tag != zero_tag,
130 };
131 if w_changed || coeffs.w_factor != cache.w_factor {
132 return true;
133 }
134 let dx_changed = match coeffs.d_x {
135 Some(d) => d.as_tagged().get_tag() != cache.d_x_tag,
136 None => cache.d_x_tag != zero_tag,
137 };
138 if dx_changed || coeffs.delta_x != cache.delta_x {
139 return true;
140 }
141 let ds_changed = match coeffs.d_s {
142 Some(d) => d.as_tagged().get_tag() != cache.d_s_tag,
143 None => cache.d_s_tag != zero_tag,
144 };
145 if ds_changed || coeffs.delta_s != cache.delta_s {
146 return true;
147 }
148 if coeffs.j_c.as_tagged().get_tag() != cache.j_c_tag {
149 return true;
150 }
151 let dc_changed = match coeffs.d_c {
152 Some(d) => d.as_tagged().get_tag() != cache.d_c_tag,
153 None => cache.d_c_tag != zero_tag,
154 };
155 if dc_changed || coeffs.delta_c != cache.delta_c {
156 return true;
157 }
158 if coeffs.j_d.as_tagged().get_tag() != cache.j_d_tag {
159 return true;
160 }
161 let dd_changed = match coeffs.d_d {
162 Some(d) => d.as_tagged().get_tag() != cache.d_d_tag,
163 None => cache.d_d_tag != zero_tag,
164 };
165 if dd_changed || coeffs.delta_d != cache.delta_d {
166 return true;
167 }
168 false
169 }
170
171 fn store_cache(&mut self, coeffs: &AugSysCoeffs<'_>) {
172 let zero_tag = Tag::NONE;
173 self.cache.w_tag = coeffs
174 .w
175 .map(|w| w.as_tagged().get_tag())
176 .unwrap_or(zero_tag);
177 self.cache.w_factor = coeffs.w_factor;
178 self.cache.d_x_tag = coeffs
179 .d_x
180 .map(|d| d.as_tagged().get_tag())
181 .unwrap_or(zero_tag);
182 self.cache.delta_x = coeffs.delta_x;
183 self.cache.d_s_tag = coeffs
184 .d_s
185 .map(|d| d.as_tagged().get_tag())
186 .unwrap_or(zero_tag);
187 self.cache.delta_s = coeffs.delta_s;
188 self.cache.j_c_tag = coeffs.j_c.as_tagged().get_tag();
189 self.cache.d_c_tag = coeffs
190 .d_c
191 .map(|d| d.as_tagged().get_tag())
192 .unwrap_or(zero_tag);
193 self.cache.delta_c = coeffs.delta_c;
194 self.cache.j_d_tag = coeffs.j_d.as_tagged().get_tag();
195 self.cache.d_d_tag = coeffs
196 .d_d
197 .map(|d| d.as_tagged().get_tag())
198 .unwrap_or(zero_tag);
199 self.cache.delta_d = coeffs.delta_d;
200 }
201
202 pub fn first_call(&self) -> bool {
203 self.first_call
204 }
205
206 pub fn cache(&self) -> &AugSysCache {
207 &self.cache
208 }
209
210 fn update_factorization(
217 &mut self,
218 lr_w: &LowRankUpdateSymMatrix,
219 coeffs: &AugSysCoeffs<'_>,
220 proto: &AugSysRhs<'_>,
221 check_neg_evals: bool,
222 num_neg_evals: Index,
223 ) -> ESymSolverStatus {
224 let proto_x = downcast_dense(proto.rhs_x);
225 let proto_s = downcast_dense(proto.rhs_s);
226 let proto_c = downcast_dense(proto.rhs_c);
227 let proto_d = downcast_dense(proto.rhs_d);
228 let space_x = Rc::clone(proto_x.space());
229 let space_s = Rc::clone(proto_s.space());
230 let space_c = Rc::clone(proto_c.space());
231 let space_d = Rc::clone(proto_d.space());
232
233 let b0_dense: DenseVector = if coeffs.w_factor == 1.0 {
237 match lr_w.get_diag() {
238 Some(d) => clone_dense(downcast_dense(d.as_ref())),
239 None => zero_x_for(&space_x, lr_w),
240 }
241 } else {
242 zero_x_for(&space_x, lr_w)
243 };
244
245 let wdiag_diag: Rc<dyn Vector> = match (lr_w.p_lowrank(), lr_w.reduced_diag()) {
246 (Some(p_lm), true) => {
247 let mut fullx = space_x.make_new_dense();
249 p_lm.mult_vector(1.0, &b0_dense, 0.0, &mut fullx);
250 Rc::new(fullx) as Rc<dyn Vector>
251 }
252 _ => Rc::new(clone_dense(&b0_dense)) as Rc<dyn Vector>,
253 };
254 let mut wdiag = Box::new(DiagMatrix::new(space_x.dim()));
255 wdiag.set_diag(wdiag_diag);
256 self.factor.wdiag = Some(wdiag);
257
258 if coeffs.w_factor == 1.0 && lr_w.get_v().is_some() {
260 let v = Rc::clone(lr_w.get_v().unwrap());
261 let n_v = v.n_cols();
262
263 let v_x_space = MultiVectorMatrixSpace::new(n_v, Rc::clone(&space_x));
267 let mut v_x = v_x_space.make_new_multi_vector();
268 for k in 0..n_v {
269 let vk = Rc::clone(v.get_vector(k));
270 let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
271 Some(p_lm) => {
272 let mut fullx = space_x.make_new_dense();
273 p_lm.mult_vector(1.0, vk.as_ref(), 0.0, &mut fullx);
274 Rc::new(fullx) as Rc<dyn Vector>
275 }
276 None => vk,
277 };
278 v_x.set_vector(k, rhs_x_k);
279 }
280
281 let (vt_x, vt_s, vt_c, vt_d) = self.multi_solve_block(
282 &v_x,
283 coeffs,
284 &space_x,
285 &space_s,
286 &space_c,
287 &space_d,
288 check_neg_evals,
289 num_neg_evals,
290 );
291
292 let vt_x = match vt_x {
293 Ok(x) => x,
294 Err(status) => return status,
295 };
296
297 let m1_space = DenseSymMatrixSpace::new(n_v);
299 let mut m1 = m1_space.make_new_dense_sym();
300 m1.fill_identity(1.0);
301 m1.high_rank_update_transpose(1.0, &vt_x, &v_x, 1.0);
302 let j1_space = DenseGenMatrixSpace::new(n_v, n_v);
303 let mut j1 = j1_space.make_new_dense_gen();
304 if !j1.compute_cholesky_factor(&m1) {
305 self.num_neg_evals += 1;
306 return ESymSolverStatus::WrongInertia;
307 }
308 self.factor.vtilde1_x = Some(vt_x);
309 self.factor.vtilde1_s = Some(vt_s);
310 self.factor.vtilde1_c = Some(vt_c);
311 self.factor.vtilde1_d = Some(vt_d);
312 self.factor.j1 = Some(j1);
313 } else {
314 self.factor.vtilde1_x = None;
315 self.factor.vtilde1_s = None;
316 self.factor.vtilde1_c = None;
317 self.factor.vtilde1_d = None;
318 self.factor.j1 = None;
319 }
320
321 if coeffs.w_factor == 1.0 && lr_w.get_u().is_some() {
324 let u = Rc::clone(lr_w.get_u().unwrap());
325 let n_u = u.n_cols();
326
327 let u_x_space = MultiVectorMatrixSpace::new(n_u, Rc::clone(&space_x));
328 let mut u_x = u_x_space.make_new_multi_vector();
329 for k in 0..n_u {
330 let uk = Rc::clone(u.get_vector(k));
331 let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
332 Some(p_lm) => {
333 let mut fullx = space_x.make_new_dense();
334 p_lm.mult_vector(1.0, uk.as_ref(), 0.0, &mut fullx);
335 Rc::new(fullx) as Rc<dyn Vector>
336 }
337 None => uk,
338 };
339 u_x.set_vector(k, rhs_x_k);
340 }
341
342 let (mut ut_x, mut ut_s, mut ut_c, mut ut_d) = match self.multi_solve_block(
343 &u_x,
344 coeffs,
345 &space_x,
346 &space_s,
347 &space_c,
348 &space_d,
349 check_neg_evals,
350 num_neg_evals,
351 ) {
352 (Ok(x), s, c, d) => (x, s, c, d),
353 (Err(status), _, _, _) => return status,
354 };
355
356 if self.factor.vtilde1_x.is_some() {
358 let vt1_x = self.factor.vtilde1_x.as_ref().unwrap();
359 let vt1_s = self.factor.vtilde1_s.as_ref().unwrap();
360 let vt1_c = self.factor.vtilde1_c.as_ref().unwrap();
361 let vt1_d = self.factor.vtilde1_d.as_ref().unwrap();
362 let n_v = vt1_x.n_cols();
363 let c_space = DenseGenMatrixSpace::new(n_v, n_u);
367 let mut c_mat = c_space.make_new_dense_gen();
368 {
369 let cv = c_mat.values_mut();
370 for j in 0..n_u as usize {
371 let uj = u_x.get_vector(j as Index).as_ref();
372 for i in 0..n_v as usize {
373 let vi = vt1_x.get_vector(i as Index).as_ref();
374 cv[i + j * n_v as usize] = vi.dot(uj);
375 }
376 }
377 }
378 self.factor
379 .j1
380 .as_ref()
381 .unwrap()
382 .cholesky_solve_matrix(&mut c_mat);
383 ut_x.add_right_mult_matrix(-1.0, vt1_x, &c_mat, 1.0);
384 ut_s.add_right_mult_matrix(-1.0, vt1_s, &c_mat, 1.0);
385 ut_c.add_right_mult_matrix(-1.0, vt1_c, &c_mat, 1.0);
386 ut_d.add_right_mult_matrix(-1.0, vt1_d, &c_mat, 1.0);
387 }
388
389 let m2_space = DenseSymMatrixSpace::new(n_u);
391 let mut m2 = m2_space.make_new_dense_sym();
392 m2.fill_identity(1.0);
393 m2.high_rank_update_transpose(-1.0, &ut_x, &u_x, 1.0);
394 let j2_space = DenseGenMatrixSpace::new(n_u, n_u);
395 let mut j2 = j2_space.make_new_dense_gen();
396 if !j2.compute_cholesky_factor(&m2) {
397 self.num_neg_evals += 1;
398 return ESymSolverStatus::WrongInertia;
399 }
400 self.factor.utilde2_x = Some(ut_x);
401 self.factor.utilde2_s = Some(ut_s);
402 self.factor.utilde2_c = Some(ut_c);
403 self.factor.utilde2_d = Some(ut_d);
404 self.factor.j2 = Some(j2);
405 } else {
406 self.factor.utilde2_x = None;
407 self.factor.utilde2_s = None;
408 self.factor.utilde2_c = None;
409 self.factor.utilde2_d = None;
410 self.factor.j2 = None;
411 }
412
413 ESymSolverStatus::Success
414 }
415
416 #[allow(clippy::too_many_arguments)]
421 fn multi_solve_block(
422 &mut self,
423 v_x: &MultiVectorMatrix,
424 coeffs: &AugSysCoeffs<'_>,
425 space_x: &Rc<DenseVectorSpace>,
426 space_s: &Rc<DenseVectorSpace>,
427 space_c: &Rc<DenseVectorSpace>,
428 space_d: &Rc<DenseVectorSpace>,
429 check_neg_evals: bool,
430 num_neg_evals: Index,
431 ) -> (
432 Result<MultiVectorMatrix, ESymSolverStatus>,
433 MultiVectorMatrix,
434 MultiVectorMatrix,
435 MultiVectorMatrix,
436 ) {
437 let n_cols = v_x.n_cols();
438
439 let mut out_x =
441 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_x)).make_new_multi_vector();
442 let mut out_s =
443 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_s)).make_new_multi_vector();
444 let mut out_c =
445 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_c)).make_new_multi_vector();
446 let mut out_d =
447 MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_d)).make_new_multi_vector();
448 out_x.fill_with_new_vectors();
449 out_s.fill_with_new_vectors();
450 out_c.fill_with_new_vectors();
451 out_d.fill_with_new_vectors();
452
453 let mut rhs_s = space_s.make_new_dense();
456 rhs_s.set(0.0);
457 let mut rhs_c = space_c.make_new_dense();
458 rhs_c.set(0.0);
459 let mut rhs_d = space_d.make_new_dense();
460 rhs_d.set(0.0);
461
462 for k in 0..n_cols {
463 let rhs_x_dyn: &dyn Vector = v_x.get_vector(k).as_ref();
464 let inner_rhs = AugSysRhs {
465 rhs_x: rhs_x_dyn,
466 rhs_s: rhs_s.as_dyn_vector(),
467 rhs_c: rhs_c.as_dyn_vector(),
468 rhs_d: rhs_d.as_dyn_vector(),
469 };
470 let mut sol_x = space_x.make_new_dense();
472 let mut sol_s = space_s.make_new_dense();
473 let mut sol_c = space_c.make_new_dense();
474 let mut sol_d = space_d.make_new_dense();
475 sol_x.set(0.0);
476 sol_s.set(0.0);
477 sol_c.set(0.0);
478 sol_d.set(0.0);
479 let inner_coeffs = inner_coeffs(&self.factor, coeffs);
480 let status = {
481 let mut sol = AugSysSol {
482 sol_x: &mut sol_x,
483 sol_s: &mut sol_s,
484 sol_c: &mut sol_c,
485 sol_d: &mut sol_d,
486 };
487 self.inner.solve(
488 &inner_coeffs,
489 &inner_rhs,
490 &mut sol,
491 check_neg_evals,
492 num_neg_evals,
493 )
494 };
495 if self.inner.provides_inertia() {
496 self.num_neg_evals = self.inner.number_of_neg_evals();
497 }
498 if status != ESymSolverStatus::Success {
499 return (Err(status), out_s, out_c, out_d);
500 }
501 out_x.set_vector(k, Rc::new(sol_x) as Rc<dyn Vector>);
502 out_s.set_vector(k, Rc::new(sol_s) as Rc<dyn Vector>);
503 out_c.set_vector(k, Rc::new(sol_c) as Rc<dyn Vector>);
504 out_d.set_vector(k, Rc::new(sol_d) as Rc<dyn Vector>);
505 }
506 (Ok(out_x), out_s, out_c, out_d)
507 }
508}
509
510fn inner_coeffs<'b>(factor: &'b Factorization, coeffs: &AugSysCoeffs<'b>) -> AugSysCoeffs<'b> {
515 let wdiag: &DiagMatrix = factor.wdiag.as_ref().expect("Wdiag unset").as_ref();
516 AugSysCoeffs {
517 w: Some(wdiag as &dyn SymMatrix),
518 w_factor: 1.0,
519 d_x: coeffs.d_x,
520 delta_x: coeffs.delta_x,
521 d_s: coeffs.d_s,
522 delta_s: coeffs.delta_s,
523 j_c: coeffs.j_c,
524 d_c: coeffs.d_c,
525 delta_c: coeffs.delta_c,
526 j_d: coeffs.j_d,
527 d_d: coeffs.d_d,
528 delta_d: coeffs.delta_d,
529 }
530}
531
532fn downcast_dense(v: &dyn Vector) -> &DenseVector {
533 v.as_any()
534 .downcast_ref::<DenseVector>()
535 .expect("LowRankAugSystemSolver currently requires DenseVector RHS/solutions")
536}
537
538fn clone_dense(src: &DenseVector) -> DenseVector {
542 let mut out = src.space().make_new_dense();
543 out.set_values(&src.expanded_values());
544 out
545}
546
547fn zero_x_for(space_x: &Rc<DenseVectorSpace>, lr_w: &LowRankUpdateSymMatrix) -> DenseVector {
548 let _ = lr_w;
553 let mut z = space_x.make_new_dense();
554 z.set(0.0);
555 z
556}
557
558impl AugSystemSolver for LowRankAugSystemSolver {
559 fn provides_inertia(&self) -> bool {
560 self.inner.provides_inertia()
561 }
562
563 fn number_of_neg_evals(&self) -> Index {
564 if self.inner.provides_inertia() {
565 self.inner.number_of_neg_evals()
566 } else {
567 self.num_neg_evals
568 }
569 }
570
571 fn increase_quality(&mut self) -> bool {
572 self.inner.increase_quality()
573 }
574
575 fn last_solve_status(&self) -> ESymSolverStatus {
576 self.inner.last_solve_status()
577 }
578
579 fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
580 self.inner.set_timing_stats(timing);
581 }
582
583 fn solve(
584 &mut self,
585 coeffs: &AugSysCoeffs<'_>,
586 rhs: &AugSysRhs<'_>,
587 sol: &mut AugSysSol<'_>,
588 check_neg_evals: bool,
589 num_neg_evals: Index,
590 ) -> ESymSolverStatus {
591 let mut check_neg_evals = check_neg_evals;
594 if !self.inner.provides_inertia() {
595 check_neg_evals = false;
596 }
597
598 let needs_rebuild = self.first_call || self.augmented_system_requires_change(coeffs);
599 if needs_rebuild {
600 let lr_w = match coeffs.w {
601 Some(w) => w.as_any().downcast_ref::<LowRankUpdateSymMatrix>().expect(
602 "LowRankAugSystemSolver requires a LowRankUpdateSymMatrix as its W block",
603 ),
604 None => panic!("LowRankAugSystemSolver requires a non-None W"),
605 };
606 let status =
607 self.update_factorization(lr_w, coeffs, rhs, check_neg_evals, num_neg_evals);
608 if status != ESymSolverStatus::Success {
609 return status;
610 }
611 self.store_cache(coeffs);
612 self.first_call = false;
613 }
614
615 let ic = inner_coeffs(&self.factor, coeffs);
617 let status = self
618 .inner
619 .solve(&ic, rhs, sol, check_neg_evals, num_neg_evals);
620 if self.inner.provides_inertia() {
621 self.num_neg_evals = self.inner.number_of_neg_evals();
622 }
623 if status != ESymSolverStatus::Success {
624 return status;
625 }
626
627 if self.factor.utilde2_x.is_some() {
630 self.apply_smw(1.0, true, rhs, sol);
631 }
632 if self.factor.vtilde1_x.is_some() {
633 self.apply_smw(-1.0, false, rhs, sol);
634 }
635
636 ESymSolverStatus::Success
637 }
638}
639
640impl LowRankAugSystemSolver {
641 fn apply_smw(&self, sign: Number, use_u: bool, rhs: &AugSysRhs<'_>, sol: &mut AugSysSol<'_>) {
647 let (mvx, mvs, mvc, mvd, j) = if use_u {
648 (
649 self.factor.utilde2_x.as_ref().unwrap(),
650 self.factor.utilde2_s.as_ref().unwrap(),
651 self.factor.utilde2_c.as_ref().unwrap(),
652 self.factor.utilde2_d.as_ref().unwrap(),
653 self.factor.j2.as_ref().unwrap(),
654 )
655 } else {
656 (
657 self.factor.vtilde1_x.as_ref().unwrap(),
658 self.factor.vtilde1_s.as_ref().unwrap(),
659 self.factor.vtilde1_c.as_ref().unwrap(),
660 self.factor.vtilde1_d.as_ref().unwrap(),
661 self.factor.j1.as_ref().unwrap(),
662 )
663 };
664 let n = mvx.n_cols();
665 let mut b_vec: Vec<Number> = Vec::with_capacity(n as usize);
669 for k in 0..n {
670 let dot = mvx.get_vector(k).dot(rhs.rhs_x)
671 + mvs.get_vector(k).dot(rhs.rhs_s)
672 + mvc.get_vector(k).dot(rhs.rhs_c)
673 + mvd.get_vector(k).dot(rhs.rhs_d);
674 b_vec.push(dot);
675 }
676 let space_b = DenseVectorSpace::new(n);
677 let mut b = space_b.make_new_dense();
678 b.set_values(&b_vec);
679 j.cholesky_solve_vector(&mut b);
681 mvx.mult_vector(sign, &b, 1.0, sol.sol_x);
683 mvs.mult_vector(sign, &b, 1.0, sol.sol_s);
684 mvc.mult_vector(sign, &b, 1.0, sol.sol_c);
685 mvd.mult_vector(sign, &b, 1.0, sol.sol_d);
686 }
687}
688
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_linalg::dense_vector::DenseVectorSpace;
    use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
    use std::cell::Cell;

    /// Stand-in inner solver that only handles the x block: it writes
    /// `rhs_x[i] / (Wdiag[i] + D_x[i] + delta_x)` into `sol_x` and leaves the
    /// s/c/d solution components untouched.
    struct DiagInner {
        /// Number of `solve` calls. Shared through an `Rc` so a test can still
        /// read it after the solver has been boxed into the wrapper.
        calls: Rc<Cell<usize>>,
    }

    impl DiagInner {
        fn new() -> Self {
            Self {
                calls: Rc::new(Cell::new(0)),
            }
        }
    }

    impl AugSystemSolver for DiagInner {
        fn provides_inertia(&self) -> bool {
            true
        }
        fn number_of_neg_evals(&self) -> Index {
            0
        }
        fn increase_quality(&mut self) -> bool {
            true
        }
        fn last_solve_status(&self) -> ESymSolverStatus {
            ESymSolverStatus::Success
        }
        fn solve(
            &mut self,
            coeffs: &AugSysCoeffs<'_>,
            rhs: &AugSysRhs<'_>,
            sol: &mut AugSysSol<'_>,
            _check_neg_evals: bool,
            _num_neg_evals: Index,
        ) -> ESymSolverStatus {
            self.calls.set(self.calls.get() + 1);
            let wdiag = coeffs
                .w
                .expect("DiagInner requires W")
                .as_any()
                .downcast_ref::<DiagMatrix>()
                .expect("DiagInner requires W to be a DiagMatrix");
            let diag_rc = wdiag.get_diag().expect("Wdiag has no diag set").clone();
            let diag = downcast_dense(diag_rc.as_ref()).expanded_values();
            let rhs_x = downcast_dense(rhs.rhs_x).expanded_values();
            let dx_vals: Option<Vec<Number>> =
                coeffs.d_x.map(|d| downcast_dense(d).expanded_values());
            let mut out = vec![0.0; rhs_x.len()];
            for i in 0..rhs_x.len() {
                let dx_i = match &dx_vals {
                    Some(v) => v[i],
                    None => 0.0,
                };
                let denom = diag[i] + dx_i + coeffs.delta_x;
                out[i] = rhs_x[i] / denom;
            }
            let sol_x_dv = sol
                .sol_x
                .as_any_mut()
                .downcast_mut::<DenseVector>()
                .unwrap();
            sol_x_dv.set_values(&out);
            ESymSolverStatus::Success
        }
    }

    fn dvec(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> DenseVector {
        let mut v = space.make_new_dense();
        v.set_values(vals);
        v
    }

    fn dvec_rc(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> Rc<DenseVector> {
        Rc::new(dvec(space, vals))
    }

    #[test]
    fn smw_recovers_low_rank_inverse() {
        // W = B0 + v v^T = 2 + 3 * 3 = 11.
        let space_x = DenseVectorSpace::new(1);
        let space_zero = DenseVectorSpace::new(0);
        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
        let mut lr = lr_space.make_new_low_rank();
        let b0_rc: Rc<dyn Vector> = dvec_rc(&space_x, &[2.0]);
        lr.set_diag(b0_rc);
        let v_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
        let mut v_mvm = v_space.make_new_multi_vector();
        v_mvm.set_vector(0, dvec_rc(&space_x, &[3.0]) as Rc<dyn Vector>);
        lr.set_v(Rc::new(v_mvm));
        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);

        let mut solver = LowRankAugSystemSolver::new(Box::new(DiagInner::new()));

        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_c = j_c_space.make_new_dense_gen();
        let j_d = j_d_space.make_new_dense_gen();

        let coeffs = AugSysCoeffs {
            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
            w_factor: 1.0,
            d_x: None,
            delta_x: 0.0,
            d_s: None,
            delta_s: 0.0,
            j_c: &j_c as &dyn Matrix,
            d_c: None,
            delta_c: 0.0,
            j_d: &j_d as &dyn Matrix,
            d_d: None,
            delta_d: 0.0,
        };

        let rhs_x = dvec(&space_x, &[5.0]);
        let rhs_s = dvec(&space_zero, &[]);
        let rhs_c = dvec(&space_zero, &[]);
        let rhs_d = dvec(&space_zero, &[]);
        let rhs = AugSysRhs {
            rhs_x: &rhs_x,
            rhs_s: &rhs_s,
            rhs_c: &rhs_c,
            rhs_d: &rhs_d,
        };
        let mut sol_x = dvec(&space_x, &[0.0]);
        let mut sol_s = dvec(&space_zero, &[]);
        let mut sol_c = dvec(&space_zero, &[]);
        let mut sol_d = dvec(&space_zero, &[]);
        let mut sol = AugSysSol {
            sol_x: &mut sol_x,
            sol_s: &mut sol_s,
            sol_c: &mut sol_c,
            sol_d: &mut sol_d,
        };
        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
        assert_eq!(status, ESymSolverStatus::Success);
        let got = sol_x.expanded_values()[0];
        let want = 5.0 / 11.0;
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    #[test]
    fn smw_with_u_only_applies_positive_correction() {
        // W = B0 - u u^T = 5 - 1.5 * 1.5 = 2.75.
        let space_x = DenseVectorSpace::new(1);
        let space_zero = DenseVectorSpace::new(0);
        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(dvec_rc(&space_x, &[5.0]));
        let u_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
        let mut u_mvm = u_space.make_new_multi_vector();
        u_mvm.set_vector(0, dvec_rc(&space_x, &[1.5]) as Rc<dyn Vector>);
        lr.set_u(Rc::new(u_mvm));
        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);

        let mut solver = LowRankAugSystemSolver::new(Box::new(DiagInner::new()));

        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_c = j_c_space.make_new_dense_gen();
        let j_d = j_d_space.make_new_dense_gen();

        let coeffs = AugSysCoeffs {
            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
            w_factor: 1.0,
            d_x: None,
            delta_x: 0.0,
            d_s: None,
            delta_s: 0.0,
            j_c: &j_c as &dyn Matrix,
            d_c: None,
            delta_c: 0.0,
            j_d: &j_d as &dyn Matrix,
            d_d: None,
            delta_d: 0.0,
        };

        let rhs_x = dvec(&space_x, &[7.0]);
        let rhs_s = dvec(&space_zero, &[]);
        let rhs_c = dvec(&space_zero, &[]);
        let rhs_d = dvec(&space_zero, &[]);
        let rhs = AugSysRhs {
            rhs_x: &rhs_x,
            rhs_s: &rhs_s,
            rhs_c: &rhs_c,
            rhs_d: &rhs_d,
        };
        let mut sol_x = dvec(&space_x, &[0.0]);
        let mut sol_s = dvec(&space_zero, &[]);
        let mut sol_c = dvec(&space_zero, &[]);
        let mut sol_d = dvec(&space_zero, &[]);
        let mut sol = AugSysSol {
            sol_x: &mut sol_x,
            sol_s: &mut sol_s,
            sol_c: &mut sol_c,
            sol_d: &mut sol_d,
        };
        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
        assert_eq!(status, ESymSolverStatus::Success);
        let got = sol_x.expanded_values()[0];
        let want = 7.0 / 2.75;
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    #[test]
    fn smw_with_v_and_u_combines_corrections() {
        // W = B0 + v v^T - u u^T = 10 + 4 - 1 = 13.
        let space_x = DenseVectorSpace::new(1);
        let space_zero = DenseVectorSpace::new(0);
        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(dvec_rc(&space_x, &[10.0]));
        let v_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
        let mut v_mvm = v_space.make_new_multi_vector();
        v_mvm.set_vector(0, dvec_rc(&space_x, &[2.0]) as Rc<dyn Vector>);
        lr.set_v(Rc::new(v_mvm));
        let u_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
        let mut u_mvm = u_space.make_new_multi_vector();
        u_mvm.set_vector(0, dvec_rc(&space_x, &[1.0]) as Rc<dyn Vector>);
        lr.set_u(Rc::new(u_mvm));
        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);

        let mut solver = LowRankAugSystemSolver::new(Box::new(DiagInner::new()));

        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_c = j_c_space.make_new_dense_gen();
        let j_d = j_d_space.make_new_dense_gen();

        let coeffs = AugSysCoeffs {
            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
            w_factor: 1.0,
            d_x: None,
            delta_x: 0.0,
            d_s: None,
            delta_s: 0.0,
            j_c: &j_c as &dyn Matrix,
            d_c: None,
            delta_c: 0.0,
            j_d: &j_d as &dyn Matrix,
            d_d: None,
            delta_d: 0.0,
        };

        let rhs_x = dvec(&space_x, &[1.0]);
        let rhs_s = dvec(&space_zero, &[]);
        let rhs_c = dvec(&space_zero, &[]);
        let rhs_d = dvec(&space_zero, &[]);
        let rhs = AugSysRhs {
            rhs_x: &rhs_x,
            rhs_s: &rhs_s,
            rhs_c: &rhs_c,
            rhs_d: &rhs_d,
        };
        let mut sol_x = dvec(&space_x, &[0.0]);
        let mut sol_s = dvec(&space_zero, &[]);
        let mut sol_c = dvec(&space_zero, &[]);
        let mut sol_d = dvec(&space_zero, &[]);
        let mut sol = AugSysSol {
            sol_x: &mut sol_x,
            sol_s: &mut sol_s,
            sol_c: &mut sol_c,
            sol_d: &mut sol_d,
        };
        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
        assert_eq!(status, ESymSolverStatus::Success);
        let got = sol_x.expanded_values()[0];
        let want = 1.0 / 13.0;
        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
    }

    #[test]
    fn unchanged_coeffs_skip_rebuild_after_first_call() {
        let inner = DiagInner::new();
        let calls = Rc::clone(&inner.calls);
        let mut lr_solver = LowRankAugSystemSolver::new(Box::new(inner));
        let space_x = DenseVectorSpace::new(1);
        let space_zero = DenseVectorSpace::new(0);
        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
        let mut lr = lr_space.make_new_low_rank();
        lr.set_diag(dvec_rc(&space_x, &[2.0]));
        // A V term makes a rebuild observable: it costs one extra inner solve.
        let v_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
        let mut v_mvm = v_space.make_new_multi_vector();
        v_mvm.set_vector(0, dvec_rc(&space_x, &[1.0]) as Rc<dyn Vector>);
        lr.set_v(Rc::new(v_mvm));
        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);
        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
        let j_c = j_c_space.make_new_dense_gen();
        let j_d = j_d_space.make_new_dense_gen();
        let coeffs = AugSysCoeffs {
            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
            w_factor: 1.0,
            d_x: None,
            delta_x: 0.001,
            d_s: None,
            delta_s: 0.0,
            j_c: &j_c as &dyn Matrix,
            d_c: None,
            delta_c: 0.0,
            j_d: &j_d as &dyn Matrix,
            d_d: None,
            delta_d: 0.0,
        };
        let rhs_x = dvec(&space_x, &[1.0]);
        let rhs_zero = dvec(&space_zero, &[]);
        let rhs = AugSysRhs {
            rhs_x: &rhs_x,
            rhs_s: &rhs_zero,
            rhs_c: &rhs_zero,
            rhs_d: &rhs_zero,
        };
        let mut sol_x = dvec(&space_x, &[0.0]);
        let mut sol_z1 = dvec(&space_zero, &[]);
        let mut sol_z2 = dvec(&space_zero, &[]);
        let mut sol_z3 = dvec(&space_zero, &[]);

        // W + delta_x = 2 + 1 * 1 + 0.001.
        let want = 1.0 / 3.001;
        // First call: one inner solve for the V column plus one for the RHS.
        // Second call: the factorization is reused, so only the RHS solve runs.
        for &expected_calls in &[2usize, 3] {
            let status = {
                let mut sol = AugSysSol {
                    sol_x: &mut sol_x,
                    sol_s: &mut sol_z1,
                    sol_c: &mut sol_z2,
                    sol_d: &mut sol_z3,
                };
                lr_solver.solve(&coeffs, &rhs, &mut sol, false, 0)
            };
            assert_eq!(status, ESymSolverStatus::Success);
            assert_eq!(calls.get(), expected_calls, "unexpected number of inner solves");
            let got = sol_x.expanded_values()[0];
            assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
            assert!(!lr_solver.first_call());
            assert!(!lr_solver.augmented_system_requires_change(&coeffs));
        }
    }
}