Skip to main content

pounce_algorithm/kkt/
low_rank_aug_system_solver.rs

1//! Low-rank augmented system solver — port of
2//! `Algorithm/IpLowRankAugSystemSolver.{hpp,cpp}`.
3//!
4//! Wraps another [`AugSystemSolver`] and exploits a [`LowRankUpdateSymMatrix`]
5//! Hessian via the Sherman-Morrison-Woodbury identity. The wrapped
6//! solver factorizes the diagonal part `B0`; this solver applies the
7//! rank-`(nV + nU)` correction using cached
8//! `Vtilde1 = K⁻¹ V` and `Utilde2 = K⁻¹ U − Vtilde1·(J1^{-T}J1^{-1}·Vtilde1ᵀU)`
9//! plus their dense Cholesky factors `J1 = chol(I + Vtilde1ᵀ V)` and
10//! `J2 = chol(I − Utilde2ᵀ U)`.
11//!
//! The augmented-system solution follows upstream's recipe
//! (`IpLowRankAugSystemSolver.cpp:179-228`):
//!
//! 1. The inner solver factors `K` (the aug system with `Wdiag` in place
//!    of `W`) and back-substitutes for `csol_diag = K⁻¹ rhs`.
//! 2. If `Utilde2_` is set, apply `csol += Utilde2 · M2⁻¹ · Utilde2ᵀ rhs`,
//!    with `M2 = I − Utilde2ᵀ U` applied through its Cholesky factor `J2`.
//! 3. If `Vtilde1_` is set, apply `csol −= Vtilde1 · M1⁻¹ · Vtilde1ᵀ rhs`,
//!    with `M1 = I + Vtilde1ᵀ V` applied through its Cholesky factor `J1`.
19//!
20//! `Vtilde1` and `Utilde2` are stored as four separate per-block
21//! [`MultiVectorMatrix`]es (x, s, c, d) — the same data that upstream
22//! packs into a 4-component `CompoundVector` of dense columns. This
23//! keeps the SMW arithmetic in dense linalg without needing a
24//! compound-vector storage class.
25
26use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
27use pounce_common::tagged::Tag;
28use pounce_common::timing::TimingStatistics;
29use pounce_common::types::{Index, Number};
30use pounce_linalg::dense_gen_matrix::{DenseGenMatrix, DenseGenMatrixSpace};
31use pounce_linalg::dense_sym_matrix::DenseSymMatrixSpace;
32use pounce_linalg::dense_vector::{DenseVector, DenseVectorSpace};
33use pounce_linalg::diag_matrix::DiagMatrix;
34use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrix;
35use pounce_linalg::multi_vector_matrix::{MultiVectorMatrix, MultiVectorMatrixSpace};
36use pounce_linalg::{Matrix, SymMatrix, Vector};
37use pounce_linsol::ESymSolverStatus;
38use std::rc::Rc;
39
40pub struct LowRankAugSystemSolver {
41    /// Inner solver that owns the diagonal factorization.
42    inner: Box<dyn AugSystemSolver>,
43    /// Whether `solve` has been called yet.
44    first_call: bool,
45    /// Cached negative-eigenvalue count.
46    num_neg_evals: Index,
47    /// Tag/scalar cache mirroring upstream's per-coefficient state.
48    cache: AugSysCache,
49    /// SMW factorization state (cleared on each rebuild).
50    factor: Factorization,
51}
52
53#[derive(Debug, Clone)]
54pub struct AugSysCache {
55    pub w_tag: Tag,
56    pub w_factor: Number,
57    pub d_x_tag: Tag,
58    pub delta_x: Number,
59    pub d_s_tag: Tag,
60    pub delta_s: Number,
61    pub j_c_tag: Tag,
62    pub d_c_tag: Tag,
63    pub delta_c: Number,
64    pub j_d_tag: Tag,
65    pub d_d_tag: Tag,
66    pub delta_d: Number,
67}
68
69impl Default for AugSysCache {
70    fn default() -> Self {
71        Self {
72            w_tag: Tag::NONE,
73            w_factor: 0.0,
74            d_x_tag: Tag::NONE,
75            delta_x: 0.0,
76            d_s_tag: Tag::NONE,
77            delta_s: 0.0,
78            j_c_tag: Tag::NONE,
79            d_c_tag: Tag::NONE,
80            delta_c: 0.0,
81            j_d_tag: Tag::NONE,
82            d_d_tag: Tag::NONE,
83            delta_d: 0.0,
84        }
85    }
86}
87
88#[derive(Default)]
89struct Factorization {
90    /// `Wdiag` substituted for `W` in every inner-solver call. Mirrors
91    /// upstream `Wdiag_`. Held mutably so we can call
92    /// [`DiagMatrix::set_diag`] on rebuild.
93    wdiag: Option<Box<DiagMatrix>>,
94    /// Dense Cholesky `J1 = chol(I + Vtilde1ᵀ · V)`. None when V is empty.
95    j1: Option<DenseGenMatrix>,
96    /// Dense Cholesky `J2 = chol(I − Utilde2ᵀ · U)`. None when U is empty.
97    j2: Option<DenseGenMatrix>,
98    /// Per-block `Vtilde1` storage (rank `nV`).
99    vtilde1_x: Option<MultiVectorMatrix>,
100    vtilde1_s: Option<MultiVectorMatrix>,
101    vtilde1_c: Option<MultiVectorMatrix>,
102    vtilde1_d: Option<MultiVectorMatrix>,
103    /// Per-block `Utilde2` storage (rank `nU`).
104    utilde2_x: Option<MultiVectorMatrix>,
105    utilde2_s: Option<MultiVectorMatrix>,
106    utilde2_c: Option<MultiVectorMatrix>,
107    utilde2_d: Option<MultiVectorMatrix>,
108}
109
110impl LowRankAugSystemSolver {
111    pub fn new(inner: Box<dyn AugSystemSolver>) -> Self {
112        Self {
113            inner,
114            first_call: true,
115            num_neg_evals: 0,
116            cache: AugSysCache::default(),
117            factor: Factorization::default(),
118        }
119    }
120
121    /// Pure tag/scalar comparison — port of upstream
122    /// `AugmentedSystemRequiresChange` (`IpLowRankAugSystemSolver.cpp:531-599`).
123    pub fn augmented_system_requires_change(&self, coeffs: &AugSysCoeffs<'_>) -> bool {
124        let cache = &self.cache;
125        let zero_tag: Tag = Tag::NONE;
126
127        let w_changed = match coeffs.w {
128            Some(w) => w.as_tagged().get_tag() != cache.w_tag,
129            None => cache.w_tag != zero_tag,
130        };
131        if w_changed || coeffs.w_factor != cache.w_factor {
132            return true;
133        }
134        let dx_changed = match coeffs.d_x {
135            Some(d) => d.as_tagged().get_tag() != cache.d_x_tag,
136            None => cache.d_x_tag != zero_tag,
137        };
138        if dx_changed || coeffs.delta_x != cache.delta_x {
139            return true;
140        }
141        let ds_changed = match coeffs.d_s {
142            Some(d) => d.as_tagged().get_tag() != cache.d_s_tag,
143            None => cache.d_s_tag != zero_tag,
144        };
145        if ds_changed || coeffs.delta_s != cache.delta_s {
146            return true;
147        }
148        if coeffs.j_c.as_tagged().get_tag() != cache.j_c_tag {
149            return true;
150        }
151        let dc_changed = match coeffs.d_c {
152            Some(d) => d.as_tagged().get_tag() != cache.d_c_tag,
153            None => cache.d_c_tag != zero_tag,
154        };
155        if dc_changed || coeffs.delta_c != cache.delta_c {
156            return true;
157        }
158        if coeffs.j_d.as_tagged().get_tag() != cache.j_d_tag {
159            return true;
160        }
161        let dd_changed = match coeffs.d_d {
162            Some(d) => d.as_tagged().get_tag() != cache.d_d_tag,
163            None => cache.d_d_tag != zero_tag,
164        };
165        if dd_changed || coeffs.delta_d != cache.delta_d {
166            return true;
167        }
168        false
169    }
170
171    fn store_cache(&mut self, coeffs: &AugSysCoeffs<'_>) {
172        let zero_tag = Tag::NONE;
173        self.cache.w_tag = coeffs
174            .w
175            .map(|w| w.as_tagged().get_tag())
176            .unwrap_or(zero_tag);
177        self.cache.w_factor = coeffs.w_factor;
178        self.cache.d_x_tag = coeffs
179            .d_x
180            .map(|d| d.as_tagged().get_tag())
181            .unwrap_or(zero_tag);
182        self.cache.delta_x = coeffs.delta_x;
183        self.cache.d_s_tag = coeffs
184            .d_s
185            .map(|d| d.as_tagged().get_tag())
186            .unwrap_or(zero_tag);
187        self.cache.delta_s = coeffs.delta_s;
188        self.cache.j_c_tag = coeffs.j_c.as_tagged().get_tag();
189        self.cache.d_c_tag = coeffs
190            .d_c
191            .map(|d| d.as_tagged().get_tag())
192            .unwrap_or(zero_tag);
193        self.cache.delta_c = coeffs.delta_c;
194        self.cache.j_d_tag = coeffs.j_d.as_tagged().get_tag();
195        self.cache.d_d_tag = coeffs
196            .d_d
197            .map(|d| d.as_tagged().get_tag())
198            .unwrap_or(zero_tag);
199        self.cache.delta_d = coeffs.delta_d;
200    }
201
202    pub fn first_call(&self) -> bool {
203        self.first_call
204    }
205
206    pub fn cache(&self) -> &AugSysCache {
207        &self.cache
208    }
209
210    /// Rebuild `Wdiag`, `Vtilde1`, `Utilde2`, `J1`, `J2` from a fresh
211    /// LR Hessian. Matches `IpLowRankAugSystemSolver.cpp::UpdateFactorization`
212    /// (lines 233-404). Returns the inner-solver's status — on
213    /// `WrongInertia` from a Cholesky failure, increments
214    /// `num_neg_evals` so the upper layer (PerturbationHandler) sees a
215    /// distinct retry target.
216    fn update_factorization(
217        &mut self,
218        lr_w: &LowRankUpdateSymMatrix,
219        coeffs: &AugSysCoeffs<'_>,
220        proto: &AugSysRhs<'_>,
221        check_neg_evals: bool,
222        num_neg_evals: Index,
223    ) -> ESymSolverStatus {
224        let proto_x = downcast_dense(proto.rhs_x);
225        let proto_s = downcast_dense(proto.rhs_s);
226        let proto_c = downcast_dense(proto.rhs_c);
227        let proto_d = downcast_dense(proto.rhs_d);
228        let space_x = Rc::clone(proto_x.space());
229        let space_s = Rc::clone(proto_s.space());
230        let space_c = Rc::clone(proto_c.space());
231        let space_d = Rc::clone(proto_d.space());
232
233        // 1. Build Wdiag from B0 (with optional P_LM expansion when
234        //    `reduced_diag` is set). When w_factor != 1.0, B0 is treated
235        //    as zero per upstream `IpLowRankAugSystemSolver.cpp:268-272`.
236        let b0_dense: DenseVector = if coeffs.w_factor == 1.0 {
237            match lr_w.get_diag() {
238                Some(d) => clone_dense(downcast_dense(d.as_ref())),
239                None => zero_x_for(&space_x, lr_w),
240            }
241        } else {
242            zero_x_for(&space_x, lr_w)
243        };
244
245        let wdiag_diag: Rc<dyn Vector> = match (lr_w.p_lowrank(), lr_w.reduced_diag()) {
246            (Some(p_lm), true) => {
247                // fullx = P_LM · B0
248                let mut fullx = space_x.make_new_dense();
249                p_lm.mult_vector(1.0, &b0_dense, 0.0, &mut fullx);
250                Rc::new(fullx) as Rc<dyn Vector>
251            }
252            _ => Rc::new(clone_dense(&b0_dense)) as Rc<dyn Vector>,
253        };
254        let mut wdiag = Box::new(DiagMatrix::new(space_x.dim()));
255        wdiag.set_diag(wdiag_diag);
256        self.factor.wdiag = Some(wdiag);
257
258        // 2. SolveMultiVector for V → Vtilde1 = K⁻¹ V (per-block).
259        if coeffs.w_factor == 1.0 && lr_w.get_v().is_some() {
260            let v = Rc::clone(lr_w.get_v().unwrap());
261            let n_v = v.n_cols();
262
263            // Build V_x: each column is either V[:,k] directly (no P_LM)
264            // or P_LM · V[:,k]. We need V_x for the M1 update; we keep
265            // it on the stack here.
266            let v_x_space = MultiVectorMatrixSpace::new(n_v, Rc::clone(&space_x));
267            let mut v_x = v_x_space.make_new_multi_vector();
268            for k in 0..n_v {
269                let vk = Rc::clone(v.get_vector(k));
270                let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
271                    Some(p_lm) => {
272                        let mut fullx = space_x.make_new_dense();
273                        p_lm.mult_vector(1.0, vk.as_ref(), 0.0, &mut fullx);
274                        Rc::new(fullx) as Rc<dyn Vector>
275                    }
276                    None => vk,
277                };
278                v_x.set_vector(k, rhs_x_k);
279            }
280
281            let (vt_x, vt_s, vt_c, vt_d) = self.multi_solve_block(
282                &v_x,
283                coeffs,
284                &space_x,
285                &space_s,
286                &space_c,
287                &space_d,
288                check_neg_evals,
289                num_neg_evals,
290            );
291
292            let vt_x = match vt_x {
293                Ok(x) => x,
294                Err(status) => return status,
295            };
296
297            // 3. M1 = I + Vtilde1_x^T · V_x; J1 = chol(M1).
298            let m1_space = DenseSymMatrixSpace::new(n_v);
299            let mut m1 = m1_space.make_new_dense_sym();
300            m1.fill_identity(1.0);
301            m1.high_rank_update_transpose(1.0, &vt_x, &v_x, 1.0);
302            let j1_space = DenseGenMatrixSpace::new(n_v, n_v);
303            let mut j1 = j1_space.make_new_dense_gen();
304            if !j1.compute_cholesky_factor(&m1) {
305                self.num_neg_evals += 1;
306                return ESymSolverStatus::WrongInertia;
307            }
308            self.factor.vtilde1_x = Some(vt_x);
309            self.factor.vtilde1_s = Some(vt_s);
310            self.factor.vtilde1_c = Some(vt_c);
311            self.factor.vtilde1_d = Some(vt_d);
312            self.factor.j1 = Some(j1);
313        } else {
314            self.factor.vtilde1_x = None;
315            self.factor.vtilde1_s = None;
316            self.factor.vtilde1_c = None;
317            self.factor.vtilde1_d = None;
318            self.factor.j1 = None;
319        }
320
321        // 4. SolveMultiVector for U → Utilde1 = K⁻¹ U; orthogonalize
322        //    against Vtilde1 (if present) to get Utilde2.
323        if coeffs.w_factor == 1.0 && lr_w.get_u().is_some() {
324            let u = Rc::clone(lr_w.get_u().unwrap());
325            let n_u = u.n_cols();
326
327            let u_x_space = MultiVectorMatrixSpace::new(n_u, Rc::clone(&space_x));
328            let mut u_x = u_x_space.make_new_multi_vector();
329            for k in 0..n_u {
330                let uk = Rc::clone(u.get_vector(k));
331                let rhs_x_k: Rc<dyn Vector> = match lr_w.p_lowrank() {
332                    Some(p_lm) => {
333                        let mut fullx = space_x.make_new_dense();
334                        p_lm.mult_vector(1.0, uk.as_ref(), 0.0, &mut fullx);
335                        Rc::new(fullx) as Rc<dyn Vector>
336                    }
337                    None => uk,
338                };
339                u_x.set_vector(k, rhs_x_k);
340            }
341
342            let (mut ut_x, mut ut_s, mut ut_c, mut ut_d) = match self.multi_solve_block(
343                &u_x,
344                coeffs,
345                &space_x,
346                &space_s,
347                &space_c,
348                &space_d,
349                check_neg_evals,
350                num_neg_evals,
351            ) {
352                (Ok(x), s, c, d) => (x, s, c, d),
353                (Err(status), _, _, _) => return status,
354            };
355
356            // 5. If Vtilde1 is present: Utilde2 = Utilde1 − Vtilde1 · (J1⁻¹J1⁻ᵀ · Vtilde1ᵀU).
357            if self.factor.vtilde1_x.is_some() {
358                let vt1_x = self.factor.vtilde1_x.as_ref().unwrap();
359                let vt1_s = self.factor.vtilde1_s.as_ref().unwrap();
360                let vt1_c = self.factor.vtilde1_c.as_ref().unwrap();
361                let vt1_d = self.factor.vtilde1_d.as_ref().unwrap();
362                let n_v = vt1_x.n_cols();
363                // C = Vtilde1_x^T · U_x  (n_v × n_u; HighRankUpdateTranspose's
364                // generic-matrix variant — we synthesize via column dot products
365                // since DenseGenMatrix doesn't expose a high_rank_update_transpose).
366                let c_space = DenseGenMatrixSpace::new(n_v, n_u);
367                let mut c_mat = c_space.make_new_dense_gen();
368                {
369                    let cv = c_mat.values_mut();
370                    for j in 0..n_u as usize {
371                        let uj = u_x.get_vector(j as Index).as_ref();
372                        for i in 0..n_v as usize {
373                            let vi = vt1_x.get_vector(i as Index).as_ref();
374                            cv[i + j * n_v as usize] = vi.dot(uj);
375                        }
376                    }
377                }
378                self.factor
379                    .j1
380                    .as_ref()
381                    .unwrap()
382                    .cholesky_solve_matrix(&mut c_mat);
383                ut_x.add_right_mult_matrix(-1.0, vt1_x, &c_mat, 1.0);
384                ut_s.add_right_mult_matrix(-1.0, vt1_s, &c_mat, 1.0);
385                ut_c.add_right_mult_matrix(-1.0, vt1_c, &c_mat, 1.0);
386                ut_d.add_right_mult_matrix(-1.0, vt1_d, &c_mat, 1.0);
387            }
388
389            // 6. M2 = I − Utilde2_x^T · U_x; J2 = chol(M2).
390            let m2_space = DenseSymMatrixSpace::new(n_u);
391            let mut m2 = m2_space.make_new_dense_sym();
392            m2.fill_identity(1.0);
393            m2.high_rank_update_transpose(-1.0, &ut_x, &u_x, 1.0);
394            let j2_space = DenseGenMatrixSpace::new(n_u, n_u);
395            let mut j2 = j2_space.make_new_dense_gen();
396            if !j2.compute_cholesky_factor(&m2) {
397                self.num_neg_evals += 1;
398                return ESymSolverStatus::WrongInertia;
399            }
400            self.factor.utilde2_x = Some(ut_x);
401            self.factor.utilde2_s = Some(ut_s);
402            self.factor.utilde2_c = Some(ut_c);
403            self.factor.utilde2_d = Some(ut_d);
404            self.factor.j2 = Some(j2);
405        } else {
406            self.factor.utilde2_x = None;
407            self.factor.utilde2_s = None;
408            self.factor.utilde2_c = None;
409            self.factor.utilde2_d = None;
410            self.factor.j2 = None;
411        }
412
413        ESymSolverStatus::Success
414    }
415
416    /// Solve `K · Vtilde = [V_x; 0; 0; 0]` for one block of right-hand
417    /// sides packed in `v_x` (dense column-by-column). Returns the four
418    /// per-block columns of `Vtilde`. Mirrors the inner loop of
419    /// upstream `SolveMultiVector` (`IpLowRankAugSystemSolver.cpp:406-528`).
420    #[allow(clippy::too_many_arguments)]
421    fn multi_solve_block(
422        &mut self,
423        v_x: &MultiVectorMatrix,
424        coeffs: &AugSysCoeffs<'_>,
425        space_x: &Rc<DenseVectorSpace>,
426        space_s: &Rc<DenseVectorSpace>,
427        space_c: &Rc<DenseVectorSpace>,
428        space_d: &Rc<DenseVectorSpace>,
429        check_neg_evals: bool,
430        num_neg_evals: Index,
431    ) -> (
432        Result<MultiVectorMatrix, ESymSolverStatus>,
433        MultiVectorMatrix,
434        MultiVectorMatrix,
435        MultiVectorMatrix,
436    ) {
437        let n_cols = v_x.n_cols();
438
439        // Allocate four per-block result MVMs.
440        let mut out_x =
441            MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_x)).make_new_multi_vector();
442        let mut out_s =
443            MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_s)).make_new_multi_vector();
444        let mut out_c =
445            MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_c)).make_new_multi_vector();
446        let mut out_d =
447            MultiVectorMatrixSpace::new(n_cols, Rc::clone(space_d)).make_new_multi_vector();
448        out_x.fill_with_new_vectors();
449        out_s.fill_with_new_vectors();
450        out_c.fill_with_new_vectors();
451        out_d.fill_with_new_vectors();
452
453        // Allocate zero RHS slots once; the four columns are reused
454        // because we re-zero per call.
455        let mut rhs_s = space_s.make_new_dense();
456        rhs_s.set(0.0);
457        let mut rhs_c = space_c.make_new_dense();
458        rhs_c.set(0.0);
459        let mut rhs_d = space_d.make_new_dense();
460        rhs_d.set(0.0);
461
462        for k in 0..n_cols {
463            let rhs_x_dyn: &dyn Vector = v_x.get_vector(k).as_ref();
464            let inner_rhs = AugSysRhs {
465                rhs_x: rhs_x_dyn,
466                rhs_s: rhs_s.as_dyn_vector(),
467                rhs_c: rhs_c.as_dyn_vector(),
468                rhs_d: rhs_d.as_dyn_vector(),
469            };
470            // Build solution slots (fresh each iteration).
471            let mut sol_x = space_x.make_new_dense();
472            let mut sol_s = space_s.make_new_dense();
473            let mut sol_c = space_c.make_new_dense();
474            let mut sol_d = space_d.make_new_dense();
475            sol_x.set(0.0);
476            sol_s.set(0.0);
477            sol_c.set(0.0);
478            sol_d.set(0.0);
479            let inner_coeffs = inner_coeffs(&self.factor, coeffs);
480            let status = {
481                let mut sol = AugSysSol {
482                    sol_x: &mut sol_x,
483                    sol_s: &mut sol_s,
484                    sol_c: &mut sol_c,
485                    sol_d: &mut sol_d,
486                };
487                self.inner.solve(
488                    &inner_coeffs,
489                    &inner_rhs,
490                    &mut sol,
491                    check_neg_evals,
492                    num_neg_evals,
493                )
494            };
495            if self.inner.provides_inertia() {
496                self.num_neg_evals = self.inner.number_of_neg_evals();
497            }
498            if status != ESymSolverStatus::Success {
499                return (Err(status), out_s, out_c, out_d);
500            }
501            out_x.set_vector(k, Rc::new(sol_x) as Rc<dyn Vector>);
502            out_s.set_vector(k, Rc::new(sol_s) as Rc<dyn Vector>);
503            out_c.set_vector(k, Rc::new(sol_c) as Rc<dyn Vector>);
504            out_d.set_vector(k, Rc::new(sol_d) as Rc<dyn Vector>);
505        }
506        (Ok(out_x), out_s, out_c, out_d)
507    }
508}
509
510/// Build inner-solver coefficients that substitute `Wdiag` for `W`.
511/// Free function (rather than method on `LowRankAugSystemSolver`) so
512/// the borrow is on `&Factorization` only — leaving `self.inner`
513/// available for `&mut`.
514fn inner_coeffs<'b>(factor: &'b Factorization, coeffs: &AugSysCoeffs<'b>) -> AugSysCoeffs<'b> {
515    let wdiag: &DiagMatrix = factor.wdiag.as_ref().expect("Wdiag unset").as_ref();
516    AugSysCoeffs {
517        w: Some(wdiag as &dyn SymMatrix),
518        w_factor: 1.0,
519        d_x: coeffs.d_x,
520        delta_x: coeffs.delta_x,
521        d_s: coeffs.d_s,
522        delta_s: coeffs.delta_s,
523        j_c: coeffs.j_c,
524        d_c: coeffs.d_c,
525        delta_c: coeffs.delta_c,
526        j_d: coeffs.j_d,
527        d_d: coeffs.d_d,
528        delta_d: coeffs.delta_d,
529    }
530}
531
532fn downcast_dense(v: &dyn Vector) -> &DenseVector {
533    v.as_any()
534        .downcast_ref::<DenseVector>()
535        .expect("LowRankAugSystemSolver currently requires DenseVector RHS/solutions")
536}
537
538/// `DenseVector` doesn't implement `Clone`; this builds a fresh dense
539/// vector in the same space populated with the same expanded values.
540/// Cheap when the source is homogeneous.
541fn clone_dense(src: &DenseVector) -> DenseVector {
542    let mut out = src.space().make_new_dense();
543    out.set_values(&src.expanded_values());
544    out
545}
546
547fn zero_x_for(space_x: &Rc<DenseVectorSpace>, lr_w: &LowRankUpdateSymMatrix) -> DenseVector {
548    // `MakeNew` either from the LR vector space (when reduced_diag is
549    // active) or from the proto x-space. We don't have the LR vector
550    // space surfaced directly, but B0 lives in either space; passing
551    // None always means "no diag" so we just return a zero in space_x.
552    let _ = lr_w;
553    let mut z = space_x.make_new_dense();
554    z.set(0.0);
555    z
556}
557
558impl AugSystemSolver for LowRankAugSystemSolver {
559    fn provides_inertia(&self) -> bool {
560        self.inner.provides_inertia()
561    }
562
563    fn number_of_neg_evals(&self) -> Index {
564        if self.inner.provides_inertia() {
565            self.inner.number_of_neg_evals()
566        } else {
567            self.num_neg_evals
568        }
569    }
570
571    fn increase_quality(&mut self) -> bool {
572        self.inner.increase_quality()
573    }
574
575    fn last_solve_status(&self) -> ESymSolverStatus {
576        self.inner.last_solve_status()
577    }
578
579    fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
580        self.inner.set_timing_stats(timing);
581    }
582
583    fn solve(
584        &mut self,
585        coeffs: &AugSysCoeffs<'_>,
586        rhs: &AugSysRhs<'_>,
587        sol: &mut AugSysSol<'_>,
588        check_neg_evals: bool,
589        num_neg_evals: Index,
590    ) -> ESymSolverStatus {
591        // Skip inertia checks when the inner solver doesn't provide
592        // them — mirrors `IpLowRankAugSystemSolver.cpp:102-105`.
593        let mut check_neg_evals = check_neg_evals;
594        if !self.inner.provides_inertia() {
595            check_neg_evals = false;
596        }
597
598        let needs_rebuild = self.first_call || self.augmented_system_requires_change(coeffs);
599        if needs_rebuild {
600            let lr_w = match coeffs.w {
601                Some(w) => w.as_any().downcast_ref::<LowRankUpdateSymMatrix>().expect(
602                    "LowRankAugSystemSolver requires a LowRankUpdateSymMatrix as its W block",
603                ),
604                None => panic!("LowRankAugSystemSolver requires a non-None W"),
605            };
606            let status =
607                self.update_factorization(lr_w, coeffs, rhs, check_neg_evals, num_neg_evals);
608            if status != ESymSolverStatus::Success {
609                return status;
610            }
611            self.store_cache(coeffs);
612            self.first_call = false;
613        }
614
615        // 1. Diagonal solve through the inner aug-system solver.
616        let ic = inner_coeffs(&self.factor, coeffs);
617        let status = self
618            .inner
619            .solve(&ic, rhs, sol, check_neg_evals, num_neg_evals);
620        if self.inner.provides_inertia() {
621            self.num_neg_evals = self.inner.number_of_neg_evals();
622        }
623        if status != ESymSolverStatus::Success {
624            return status;
625        }
626
627        // 2. SMW correction terms — mirror upstream's order:
628        //    apply Utilde2 first, then Vtilde1 (cpp:210-227).
629        if self.factor.utilde2_x.is_some() {
630            self.apply_smw(/*sign=*/ 1.0, /*use_u=*/ true, rhs, sol);
631        }
632        if self.factor.vtilde1_x.is_some() {
633            self.apply_smw(/*sign=*/ -1.0, /*use_u=*/ false, rhs, sol);
634        }
635
636        ESymSolverStatus::Success
637    }
638}
639
640impl LowRankAugSystemSolver {
641    /// Apply one SMW correction step:
642    ///   `b = U_or_Vᵀ · rhs;  J⁻¹J⁻ᵀ b;  sol += sign · U_or_V · b`
643    ///
644    /// `use_u = true` selects `(Utilde2, J2, +1)`; `false` selects
645    /// `(Vtilde1, J1, −1)` (sign passed in by caller).
646    fn apply_smw(&self, sign: Number, use_u: bool, rhs: &AugSysRhs<'_>, sol: &mut AugSysSol<'_>) {
647        let (mvx, mvs, mvc, mvd, j) = if use_u {
648            (
649                self.factor.utilde2_x.as_ref().unwrap(),
650                self.factor.utilde2_s.as_ref().unwrap(),
651                self.factor.utilde2_c.as_ref().unwrap(),
652                self.factor.utilde2_d.as_ref().unwrap(),
653                self.factor.j2.as_ref().unwrap(),
654            )
655        } else {
656            (
657                self.factor.vtilde1_x.as_ref().unwrap(),
658                self.factor.vtilde1_s.as_ref().unwrap(),
659                self.factor.vtilde1_c.as_ref().unwrap(),
660                self.factor.vtilde1_d.as_ref().unwrap(),
661                self.factor.j1.as_ref().unwrap(),
662            )
663        };
664        let n = mvx.n_cols();
665        // Build `b = M^T · crhs` from the four blocks. Reduction order
666        // matches upstream's CompoundVector dot, which iterates blocks
667        // in the order x, s, c, d (`IpCompoundVector.cpp::Dot`).
668        let mut b_vec: Vec<Number> = Vec::with_capacity(n as usize);
669        for k in 0..n {
670            let dot = mvx.get_vector(k).dot(rhs.rhs_x)
671                + mvs.get_vector(k).dot(rhs.rhs_s)
672                + mvc.get_vector(k).dot(rhs.rhs_c)
673                + mvd.get_vector(k).dot(rhs.rhs_d);
674            b_vec.push(dot);
675        }
676        let space_b = DenseVectorSpace::new(n);
677        let mut b = space_b.make_new_dense();
678        b.set_values(&b_vec);
679        // Apply J⁻¹ J⁻ᵀ in-place.
680        j.cholesky_solve_vector(&mut b);
681        // sol += sign · M · b  per block.
682        mvx.mult_vector(sign, &b, 1.0, sol.sol_x);
683        mvs.mult_vector(sign, &b, 1.0, sol.sol_s);
684        mvc.mult_vector(sign, &b, 1.0, sol.sol_c);
685        mvd.mult_vector(sign, &b, 1.0, sol.sol_d);
686    }
687}
688
689#[cfg(test)]
690mod tests {
691    use super::*;
692    use pounce_linalg::dense_vector::DenseVectorSpace;
693    use pounce_linalg::low_rank_update_sym_matrix::LowRankUpdateSymMatrixSpace;
694    use std::cell::Cell;
695
696    /// Diagonal-solve stub: pretends the augmented system is just
697    /// `(W + δ_x I) · sol_x = rhs_x` with `m_c = m_d = n_s = 0`. Reads
698    /// `coeffs.w` as a `DiagMatrix` (i.e. the wdiag we built) and does
699    /// a per-element divide. Plenty for the SMW test fixture.
700    struct DiagInner {
701        calls: Cell<usize>,
702    }
703    impl AugSystemSolver for DiagInner {
704        fn provides_inertia(&self) -> bool {
705            true
706        }
707        fn number_of_neg_evals(&self) -> Index {
708            0
709        }
710        fn increase_quality(&mut self) -> bool {
711            true
712        }
713        fn last_solve_status(&self) -> ESymSolverStatus {
714            ESymSolverStatus::Success
715        }
716        fn solve(
717            &mut self,
718            coeffs: &AugSysCoeffs<'_>,
719            rhs: &AugSysRhs<'_>,
720            sol: &mut AugSysSol<'_>,
721            _check_neg_evals: bool,
722            _num_neg_evals: Index,
723        ) -> ESymSolverStatus {
724            self.calls.set(self.calls.get() + 1);
725            let wdiag = coeffs
726                .w
727                .expect("DiagInner requires W")
728                .as_any()
729                .downcast_ref::<DiagMatrix>()
730                .expect("DiagInner requires W to be a DiagMatrix");
731            let diag_rc = wdiag.get_diag().expect("Wdiag has no diag set").clone();
732            let diag = downcast_dense(diag_rc.as_ref()).expanded_values();
733            let rhs_x = downcast_dense(rhs.rhs_x).expanded_values();
734            let dx_vals: Option<Vec<Number>> =
735                coeffs.d_x.map(|d| downcast_dense(d).expanded_values());
736            let mut out = vec![0.0; rhs_x.len()];
737            for i in 0..rhs_x.len() {
738                let dx_i = match &dx_vals {
739                    Some(v) => v[i],
740                    None => 0.0,
741                };
742                let denom = diag[i] + dx_i + coeffs.delta_x;
743                out[i] = rhs_x[i] / denom;
744            }
745            let sol_x_dv = sol
746                .sol_x
747                .as_any_mut()
748                .downcast_mut::<DenseVector>()
749                .unwrap();
750            sol_x_dv.set_values(&out);
751            // Other blocks stay zero — fixture has m_c = m_d = n_s = 0.
752            ESymSolverStatus::Success
753        }
754    }
755
756    fn dvec(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> DenseVector {
757        let mut v = space.make_new_dense();
758        v.set_values(vals);
759        v
760    }
761
762    fn dvec_rc(space: &Rc<DenseVectorSpace>, vals: &[Number]) -> Rc<DenseVector> {
763        Rc::new(dvec(space, vals))
764    }
765
766    #[test]
767    fn smw_recovers_low_rank_inverse() {
768        // 1×1 system: W = b0 + v² (v ≠ 0); δ_x = 0.
769        // Direct: sol = rhs / (b0 + v²).
770        // SMW:    inner solves with diag b0 → sol_diag = rhs/b0;
771        //         correction recovers rhs/(b0 + v²).
772        let space_x = DenseVectorSpace::new(1);
773        let space_zero = DenseVectorSpace::new(0);
774        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
775        let mut lr = lr_space.make_new_low_rank();
776        let b0_rc: Rc<dyn Vector> = dvec_rc(&space_x, &[2.0]);
777        lr.set_diag(b0_rc);
778        let v_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
779        let mut v_mvm = v_space.make_new_multi_vector();
780        v_mvm.set_vector(0, dvec_rc(&space_x, &[3.0]) as Rc<dyn Vector>);
781        lr.set_v(Rc::new(v_mvm));
782        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);
783
784        let mut solver = LowRankAugSystemSolver::new(Box::new(DiagInner {
785            calls: Cell::new(0),
786        }));
787
788        // Empty Jacobians.
789        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
790        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
791        let j_c = j_c_space.make_new_dense_gen();
792        let j_d = j_d_space.make_new_dense_gen();
793
794        let coeffs = AugSysCoeffs {
795            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
796            w_factor: 1.0,
797            d_x: None,
798            delta_x: 0.0,
799            d_s: None,
800            delta_s: 0.0,
801            j_c: &j_c as &dyn Matrix,
802            d_c: None,
803            delta_c: 0.0,
804            j_d: &j_d as &dyn Matrix,
805            d_d: None,
806            delta_d: 0.0,
807        };
808
809        let rhs_x = dvec(&space_x, &[5.0]);
810        let rhs_s = dvec(&space_zero, &[]);
811        let rhs_c = dvec(&space_zero, &[]);
812        let rhs_d = dvec(&space_zero, &[]);
813        let rhs = AugSysRhs {
814            rhs_x: &rhs_x,
815            rhs_s: &rhs_s,
816            rhs_c: &rhs_c,
817            rhs_d: &rhs_d,
818        };
819        let mut sol_x = dvec(&space_x, &[0.0]);
820        let mut sol_s = dvec(&space_zero, &[]);
821        let mut sol_c = dvec(&space_zero, &[]);
822        let mut sol_d = dvec(&space_zero, &[]);
823        let mut sol = AugSysSol {
824            sol_x: &mut sol_x,
825            sol_s: &mut sol_s,
826            sol_c: &mut sol_c,
827            sol_d: &mut sol_d,
828        };
829        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
830        assert_eq!(status, ESymSolverStatus::Success);
831        // Expected: 5 / (2 + 9) = 5/11.
832        let got = sol_x.expanded_values()[0];
833        let want = 5.0 / 11.0;
834        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
835    }
836
837    #[test]
838    fn smw_with_u_only_applies_positive_correction() {
839        // 1×1 system: W = b0 − u² (low-rank *negative* update).
840        // Direct: sol = rhs / (b0 − u²).
841        let space_x = DenseVectorSpace::new(1);
842        let space_zero = DenseVectorSpace::new(0);
843        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
844        let mut lr = lr_space.make_new_low_rank();
845        lr.set_diag(dvec_rc(&space_x, &[5.0]));
846        let u_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
847        let mut u_mvm = u_space.make_new_multi_vector();
848        u_mvm.set_vector(0, dvec_rc(&space_x, &[1.5]) as Rc<dyn Vector>);
849        lr.set_u(Rc::new(u_mvm));
850        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);
851
852        let mut solver = LowRankAugSystemSolver::new(Box::new(DiagInner {
853            calls: Cell::new(0),
854        }));
855
856        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
857        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
858        let j_c = j_c_space.make_new_dense_gen();
859        let j_d = j_d_space.make_new_dense_gen();
860
861        let coeffs = AugSysCoeffs {
862            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
863            w_factor: 1.0,
864            d_x: None,
865            delta_x: 0.0,
866            d_s: None,
867            delta_s: 0.0,
868            j_c: &j_c as &dyn Matrix,
869            d_c: None,
870            delta_c: 0.0,
871            j_d: &j_d as &dyn Matrix,
872            d_d: None,
873            delta_d: 0.0,
874        };
875
876        let rhs_x = dvec(&space_x, &[7.0]);
877        let rhs_s = dvec(&space_zero, &[]);
878        let rhs_c = dvec(&space_zero, &[]);
879        let rhs_d = dvec(&space_zero, &[]);
880        let rhs = AugSysRhs {
881            rhs_x: &rhs_x,
882            rhs_s: &rhs_s,
883            rhs_c: &rhs_c,
884            rhs_d: &rhs_d,
885        };
886        let mut sol_x = dvec(&space_x, &[0.0]);
887        let mut sol_s = dvec(&space_zero, &[]);
888        let mut sol_c = dvec(&space_zero, &[]);
889        let mut sol_d = dvec(&space_zero, &[]);
890        let mut sol = AugSysSol {
891            sol_x: &mut sol_x,
892            sol_s: &mut sol_s,
893            sol_c: &mut sol_c,
894            sol_d: &mut sol_d,
895        };
896        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
897        assert_eq!(status, ESymSolverStatus::Success);
898        // Expected: 7 / (5 − 2.25) = 7 / 2.75.
899        let got = sol_x.expanded_values()[0];
900        let want = 7.0 / 2.75;
901        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
902    }
903
904    #[test]
905    fn smw_with_v_and_u_combines_corrections() {
906        // 1×1 system: W = b0 + v² − u² (rank-2 update). Solve checks
907        // both correction passes compose correctly.
908        let space_x = DenseVectorSpace::new(1);
909        let space_zero = DenseVectorSpace::new(0);
910        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
911        let mut lr = lr_space.make_new_low_rank();
912        lr.set_diag(dvec_rc(&space_x, &[10.0]));
913        let v_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
914        let mut v_mvm = v_space.make_new_multi_vector();
915        v_mvm.set_vector(0, dvec_rc(&space_x, &[2.0]) as Rc<dyn Vector>);
916        lr.set_v(Rc::new(v_mvm));
917        let u_space = MultiVectorMatrixSpace::new(1, Rc::clone(&space_x));
918        let mut u_mvm = u_space.make_new_multi_vector();
919        u_mvm.set_vector(0, dvec_rc(&space_x, &[1.0]) as Rc<dyn Vector>);
920        lr.set_u(Rc::new(u_mvm));
921        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);
922
923        let mut solver = LowRankAugSystemSolver::new(Box::new(DiagInner {
924            calls: Cell::new(0),
925        }));
926
927        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
928        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
929        let j_c = j_c_space.make_new_dense_gen();
930        let j_d = j_d_space.make_new_dense_gen();
931
932        let coeffs = AugSysCoeffs {
933            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
934            w_factor: 1.0,
935            d_x: None,
936            delta_x: 0.0,
937            d_s: None,
938            delta_s: 0.0,
939            j_c: &j_c as &dyn Matrix,
940            d_c: None,
941            delta_c: 0.0,
942            j_d: &j_d as &dyn Matrix,
943            d_d: None,
944            delta_d: 0.0,
945        };
946
947        let rhs_x = dvec(&space_x, &[1.0]);
948        let rhs_s = dvec(&space_zero, &[]);
949        let rhs_c = dvec(&space_zero, &[]);
950        let rhs_d = dvec(&space_zero, &[]);
951        let rhs = AugSysRhs {
952            rhs_x: &rhs_x,
953            rhs_s: &rhs_s,
954            rhs_c: &rhs_c,
955            rhs_d: &rhs_d,
956        };
957        let mut sol_x = dvec(&space_x, &[0.0]);
958        let mut sol_s = dvec(&space_zero, &[]);
959        let mut sol_c = dvec(&space_zero, &[]);
960        let mut sol_d = dvec(&space_zero, &[]);
961        let mut sol = AugSysSol {
962            sol_x: &mut sol_x,
963            sol_s: &mut sol_s,
964            sol_c: &mut sol_c,
965            sol_d: &mut sol_d,
966        };
967        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
968        assert_eq!(status, ESymSolverStatus::Success);
969        // Expected: 1 / (10 + 4 − 1) = 1/13.
970        let got = sol_x.expanded_values()[0];
971        let want = 1.0 / 13.0;
972        assert!((got - want).abs() < 1e-12, "got {} want {}", got, want);
973    }
974
975    #[test]
976    fn unchanged_coeffs_skip_rebuild_after_first_call() {
977        let mut lr_solver = LowRankAugSystemSolver::new(Box::new(DiagInner {
978            calls: Cell::new(0),
979        }));
980        let space_x = DenseVectorSpace::new(1);
981        let space_zero = DenseVectorSpace::new(0);
982        let lr_space = LowRankUpdateSymMatrixSpace::new(1, None, false);
983        let mut lr = lr_space.make_new_low_rank();
984        lr.set_diag(dvec_rc(&space_x, &[2.0]));
985        let lr_rc: Rc<LowRankUpdateSymMatrix> = Rc::new(lr);
986        let j_c_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
987        let j_d_space = pounce_linalg::dense_gen_matrix::DenseGenMatrixSpace::new(0, 1);
988        let j_c = j_c_space.make_new_dense_gen();
989        let j_d = j_d_space.make_new_dense_gen();
990        let coeffs = AugSysCoeffs {
991            w: Some(lr_rc.as_ref() as &dyn SymMatrix),
992            w_factor: 1.0,
993            d_x: None,
994            delta_x: 0.001,
995            d_s: None,
996            delta_s: 0.0,
997            j_c: &j_c as &dyn Matrix,
998            d_c: None,
999            delta_c: 0.0,
1000            j_d: &j_d as &dyn Matrix,
1001            d_d: None,
1002            delta_d: 0.0,
1003        };
1004        let rhs_x = dvec(&space_x, &[1.0]);
1005        let rhs_zero = dvec(&space_zero, &[]);
1006        let rhs = AugSysRhs {
1007            rhs_x: &rhs_x,
1008            rhs_s: &rhs_zero,
1009            rhs_c: &rhs_zero,
1010            rhs_d: &rhs_zero,
1011        };
1012        let mut sol_x = dvec(&space_x, &[0.0]);
1013        let mut sol_z1 = dvec(&space_zero, &[]);
1014        let mut sol_z2 = dvec(&space_zero, &[]);
1015        let mut sol_z3 = dvec(&space_zero, &[]);
1016        {
1017            let mut sol = AugSysSol {
1018                sol_x: &mut sol_x,
1019                sol_s: &mut sol_z1,
1020                sol_c: &mut sol_z2,
1021                sol_d: &mut sol_z3,
1022            };
1023            lr_solver.solve(&coeffs, &rhs, &mut sol, false, 0);
1024        }
1025        // Same coeffs → cache reports no change.
1026        assert!(!lr_solver.augmented_system_requires_change(&coeffs));
1027    }
1028}