Skip to main content

pounce_algorithm/kkt/
aug_system_solver.rs

//! Augmented-system solver trait — port of `IpAugSystemSolver.hpp`.
//!
//! Solves the symmetric saddle-point system
//!
//! ```text
//!   [ W·factor + Σ_x + δ_x I       0          J_c^T   J_d^T ] [ dx ]   [ rx ]
//!   [          0           Σ_s + δ_s I        0       -I    ] [ ds ] = [ rs ]
//!   [         J_c                  0       -Σ_c-δ_c    0    ] [ dyc]   [ rc ]
//!   [         J_d                 -I            0   -Σ_d-δ_d] [ dyd]   [ rd ]
//! ```
//!
//! See `KKT_SYSTEM.md` §3 for the sign convention. `Σ_x = D_x`,
//! `Σ_s = D_s`, `Σ_c = D_c` and `Σ_d = D_d` are the diagonal weights pulled
//! from `IpoptCalculatedQuantities`; each weight vector is applied as a
//! diagonal matrix. Any of the `D_*` may be `None`, which is interpreted as
//! zero. `delta_*` are the perturbations driven by the `PerturbationHandler`.
17
18use pounce_common::timing::TimingStatistics;
19use pounce_common::types::{Index, Number};
20use pounce_linalg::{Matrix, SymMatrix, Vector};
21use pounce_linsol::{ESymSolverStatus, FactorPattern};
22use std::rc::Rc;
23
24/// Bundle of the matrices/vectors that define one augmented-system
25/// instance. Lives only for the duration of the call. Mirrors the
26/// long argument list of upstream `AugSystemSolver::Solve`.
27pub struct AugSysCoeffs<'a> {
28    /// Hessian-of-Lagrangian block. `None` means W = 0 (used by
29    /// `LeastSquareMults` and the resto-NLP equality multiplier
30    /// estimate).
31    pub w: Option<&'a dyn SymMatrix>,
32    /// Multiplier on `W` (typically 1.0; restoration uses ζ).
33    pub w_factor: Number,
34    /// `D_x`, the (1,1) primal weight diagonal. `None` means zero.
35    pub d_x: Option<&'a dyn Vector>,
36    pub delta_x: Number,
37    /// `D_s`, the (2,2) slack weight diagonal. `None` means zero.
38    pub d_s: Option<&'a dyn Vector>,
39    pub delta_s: Number,
40    /// Equality-constraint Jacobian, `m_c × n_x`.
41    pub j_c: &'a dyn Matrix,
42    /// `D_c`, the (3,3) diagonal weight. `None` means zero. Goes in
43    /// with a *negative* sign, matching upstream.
44    pub d_c: Option<&'a dyn Vector>,
45    pub delta_c: Number,
46    /// Inequality-constraint Jacobian, `m_d × n_x`.
47    pub j_d: &'a dyn Matrix,
48    /// `D_d`, the (4,4) diagonal weight. `None` means zero. Goes in
49    /// with a *negative* sign, matching upstream.
50    pub d_d: Option<&'a dyn Vector>,
51    pub delta_d: Number,
52}
53
54/// Right-hand sides for one solve. All four slices are required;
55/// upstream always provides all four (even if some are zero).
56pub struct AugSysRhs<'a> {
57    pub rhs_x: &'a dyn Vector,
58    pub rhs_s: &'a dyn Vector,
59    pub rhs_c: &'a dyn Vector,
60    pub rhs_d: &'a dyn Vector,
61}
62
63/// Solution slots, written in place. Must already be sized to match
64/// the corresponding RHS dim.
65pub struct AugSysSol<'a> {
66    pub sol_x: &'a mut dyn Vector,
67    pub sol_s: &'a mut dyn Vector,
68    pub sol_c: &'a mut dyn Vector,
69    pub sol_d: &'a mut dyn Vector,
70}
71
72/// Trait surface mirroring `Ipopt::AugSystemSolver`.
73pub trait AugSystemSolver {
74    /// Whether the underlying linear solver reports inertia.
75    fn provides_inertia(&self) -> bool;
76
77    /// Number of negative eigenvalues observed in the most recent
78    /// factorization. Caller checks `provides_inertia()` first.
79    fn number_of_neg_evals(&self) -> Index;
80
81    /// Dimension of the assembled augmented (KKT) system. Used by the
82    /// interactive debugger to report inertia; default 0 for backends
83    /// that don't track it.
84    fn system_dim(&self) -> Index {
85        0
86    }
87
88    /// Triplets of the assembled KKT matrix `(dim, irn, jcn, vals)`
89    /// (1-based lower triangle), for the debugger's `viz kkt`. Default
90    /// `None` for backends that don't expose them.
91    fn kkt_triplets(&self) -> Option<(Index, Vec<Index>, Vec<Index>, Vec<Number>)> {
92        None
93    }
94
95    /// The `LDLᵀ` factor pattern (and optionally values) of the most
96    /// recent factorization, for the debugger's `viz L`. Default `None`.
97    fn l_factor(&self, _want_values: bool) -> Option<FactorPattern> {
98        None
99    }
100
101    /// Ask the underlying solver for higher-quality pivoting.
102    fn increase_quality(&mut self) -> bool;
103
104    /// Status of the most recent `solve` call.
105    fn last_solve_status(&self) -> ESymSolverStatus;
106
107    /// Install the shared per-solve `TimingStatistics` so the
108    /// linear-system factor/back-solve calls are attributed to
109    /// `linear_system_factorization` / `linear_system_back_solve`.
110    /// Default impl is a no-op (timing disabled); the standard
111    /// solver overrides to record both fields, and composite solvers
112    /// (LowRank) forward to their inner solver.
113    fn set_timing_stats(&mut self, _timing: Rc<TimingStatistics>) {}
114
115    /// Install the shared per-solve diagnostics state so KKT-dump
116    /// sites can consult per-iter gating. Default impl is a no-op
117    /// (diagnostics disabled); the standard solver overrides to wire
118    /// in the dump path.
119    fn set_diagnostics(&mut self, _diag: Rc<pounce_common::diagnostics::DiagnosticsState>) {}
120
121    /// One factor + back-substitution for the full 4×4 block system.
122    /// `check_neg_evals=true` asks the linsol to verify that the
123    /// observed inertia equals `num_neg_evals`; on mismatch the
124    /// status is `WrongInertia` and the solution is left untouched.
125    fn solve(
126        &mut self,
127        coeffs: &AugSysCoeffs<'_>,
128        rhs: &AugSysRhs<'_>,
129        sol: &mut AugSysSol<'_>,
130        check_neg_evals: bool,
131        num_neg_evals: Index,
132    ) -> ESymSolverStatus;
133
134    /// Back-substitution only, reusing the factorization from the most
135    /// recent successful `solve`. Caller must guarantee the augmented
136    /// matrix is byte-identical to that solve (same W, J_c, J_d, all
137    /// diagonals, all perturbations, same pivot tolerance). Used by
138    /// `PdFullSpaceSolver`'s iterative-refinement loop and same-matrix
139    /// fast path to avoid the per-iter MA57BD refactor that dominates
140    /// pounce-ma57 wall time on long-iter problems (e.g. cont5_2_4_l
141    /// drops from 97s → ~30s once refactor-per-refinement is gone).
142    ///
143    /// Default impl falls through to `solve` (correct but slow);
144    /// `StdAugSystemSolver` overrides to skip `refill_values` and pass
145    /// `new_matrix=false` to the linear solver.
146    fn resolve(
147        &mut self,
148        coeffs: &AugSysCoeffs<'_>,
149        rhs: &AugSysRhs<'_>,
150        sol: &mut AugSysSol<'_>,
151    ) -> ESymSolverStatus {
152        self.solve(coeffs, rhs, sol, false, 0)
153    }
154
155    /// Solve the same KKT system for `nrhs` right-hand sides. Default
156    /// impl loops [`solve`]; concrete backends override only when they
157    /// can amortize factorization across calls. Mirrors upstream's
158    /// `AugSystemSolver::MultiSolve` (`IpAugSystemSolver.hpp:113-150`).
159    ///
160    /// `rhs_list` and `sol_list` must have the same length; each pair
161    /// describes one independent solve. The same `coeffs` are used for
162    /// every column.
163    fn multi_solve(
164        &mut self,
165        coeffs: &AugSysCoeffs<'_>,
166        rhs_list: &[&AugSysRhs<'_>],
167        sol_list: &mut [&mut AugSysSol<'_>],
168        check_neg_evals: bool,
169        num_neg_evals: Index,
170    ) -> ESymSolverStatus {
171        debug_assert_eq!(rhs_list.len(), sol_list.len());
172        for (rhs, sol) in rhs_list.iter().zip(sol_list.iter_mut()) {
173            let status = self.solve(coeffs, rhs, *sol, check_neg_evals, num_neg_evals);
174            if status != ESymSolverStatus::Success {
175                return status;
176            }
177        }
178        ESymSolverStatus::Success
179    }
180
181    /// Back-substitution only against the cached factor for
182    /// `nrhs` right-hand sides, packed in **column-major** layout in
183    /// `packed_rhs`. Each column has length `dim = n_x + n_s + n_y_c +
184    /// n_y_d` (the aug-system dim — z/v blocks are not part of this
185    /// path; callers expand them via `expand_bound_multipliers` after
186    /// the fact). Solutions overwrite `packed_rhs` in place.
187    ///
188    /// Returns `None` when the backend does not support this fast
189    /// path; the caller should then fall back to a per-RHS loop over
190    /// [`resolve`]. The contract on `coeffs` and `have_factor` matches
191    /// [`resolve`]'s.
192    ///
193    /// `StdAugSystemSolver` overrides this to forward to
194    /// `pounce_linsol::TSymLinearSolver::multi_solve` with `nrhs > 1`,
195    /// which lets the underlying backend (FERAL / MA57 / LAPACK)
196    /// amortize per-call setup and, where supported, block the
197    /// triangular solves. Used by `pounce-sensitivity` for the JaxProblem
198    /// `jacrev` backward, where every cotangent re-solves against the
199    /// same converged factor (pounce#77 follow-up).
200    fn try_resolve_many_flat(
201        &mut self,
202        _coeffs: &AugSysCoeffs<'_>,
203        _packed_rhs: &mut [Number],
204        _nrhs: usize,
205    ) -> Option<ESymSolverStatus> {
206        None
207    }
208}