pounce_algorithm/kkt/aug_system_solver.rs
1//! Augmented-system solver trait — port of `IpAugSystemSolver.hpp`.
2//!
3//! Solves the symmetric saddle-point system
4//!
5//! ```text
//! [ W·factor + Σ_x + δ_x I   0              J_c^T          J_d^T        ] [ dx  ]   [ rx ]
//! [ 0                        Σ_s + δ_s I    0              -I           ] [ ds  ] = [ rs ]
//! [ J_c                      0              -Σ_c - δ_c I   0            ] [ dyc ]   [ rc ]
//! [ J_d                      -I             0              -Σ_d - δ_d I ] [ dyd ]   [ rd ]
10//! ```
11//!
12//! See `KKT_SYSTEM.md` §3 for the sign convention. `Σ_x = D_x`, `Σ_s
13//! = D_s`, `Σ_c = D_c`, `Σ_d = D_d` are the diagonal weights pulled
14//! from `IpoptCalculatedQuantities`. Any of the `D_*` may be `None`,
15//! interpreted as zero. `delta_*` are the perturbations driven by the
16//! `PerturbationHandler`.
17
18use pounce_common::timing::TimingStatistics;
19use pounce_common::types::{Index, Number};
20use pounce_linalg::{Matrix, SymMatrix, Vector};
21use pounce_linsol::ESymSolverStatus;
22use std::rc::Rc;
23
24/// Bundle of the matrices/vectors that define one augmented-system
25/// instance. Lives only for the duration of the call. Mirrors the
26/// long argument list of upstream `AugSystemSolver::Solve`.
27pub struct AugSysCoeffs<'a> {
28 /// Hessian-of-Lagrangian block. `None` means W = 0 (used by
29 /// `LeastSquareMults` and the resto-NLP equality multiplier
30 /// estimate).
31 pub w: Option<&'a dyn SymMatrix>,
32 /// Multiplier on `W` (typically 1.0; restoration uses ζ).
33 pub w_factor: Number,
34 /// `D_x`, the (1,1) primal weight diagonal. `None` means zero.
35 pub d_x: Option<&'a dyn Vector>,
36 pub delta_x: Number,
37 /// `D_s`, the (2,2) slack weight diagonal. `None` means zero.
38 pub d_s: Option<&'a dyn Vector>,
39 pub delta_s: Number,
40 /// Equality-constraint Jacobian, `m_c × n_x`.
41 pub j_c: &'a dyn Matrix,
42 /// `D_c`, the (3,3) diagonal weight. `None` means zero. Goes in
43 /// with a *negative* sign, matching upstream.
44 pub d_c: Option<&'a dyn Vector>,
45 pub delta_c: Number,
46 /// Inequality-constraint Jacobian, `m_d × n_x`.
47 pub j_d: &'a dyn Matrix,
48 /// `D_d`, the (4,4) diagonal weight. `None` means zero. Goes in
49 /// with a *negative* sign, matching upstream.
50 pub d_d: Option<&'a dyn Vector>,
51 pub delta_d: Number,
52}
53
54/// Right-hand sides for one solve. All four slices are required;
55/// upstream always provides all four (even if some are zero).
56pub struct AugSysRhs<'a> {
57 pub rhs_x: &'a dyn Vector,
58 pub rhs_s: &'a dyn Vector,
59 pub rhs_c: &'a dyn Vector,
60 pub rhs_d: &'a dyn Vector,
61}
62
63/// Solution slots, written in place. Must already be sized to match
64/// the corresponding RHS dim.
65pub struct AugSysSol<'a> {
66 pub sol_x: &'a mut dyn Vector,
67 pub sol_s: &'a mut dyn Vector,
68 pub sol_c: &'a mut dyn Vector,
69 pub sol_d: &'a mut dyn Vector,
70}
71
72/// Trait surface mirroring `Ipopt::AugSystemSolver`.
73pub trait AugSystemSolver {
74 /// Whether the underlying linear solver reports inertia.
75 fn provides_inertia(&self) -> bool;
76
77 /// Number of negative eigenvalues observed in the most recent
78 /// factorization. Caller checks `provides_inertia()` first.
79 fn number_of_neg_evals(&self) -> Index;
80
81 /// Ask the underlying solver for higher-quality pivoting.
82 fn increase_quality(&mut self) -> bool;
83
84 /// Status of the most recent `solve` call.
85 fn last_solve_status(&self) -> ESymSolverStatus;
86
87 /// Install the shared per-solve `TimingStatistics` so the
88 /// linear-system factor/back-solve calls are attributed to
89 /// `linear_system_factorization` / `linear_system_back_solve`.
90 /// Default impl is a no-op (timing disabled); the standard
91 /// solver overrides to record both fields, and composite solvers
92 /// (LowRank) forward to their inner solver.
93 fn set_timing_stats(&mut self, _timing: Rc<TimingStatistics>) {}
94
95 /// Install the shared per-solve diagnostics state so KKT-dump
96 /// sites can consult per-iter gating. Default impl is a no-op
97 /// (diagnostics disabled); the standard solver overrides to wire
98 /// in the dump path.
99 fn set_diagnostics(&mut self, _diag: Rc<pounce_common::diagnostics::DiagnosticsState>) {}
100
101 /// One factor + back-substitution for the full 4×4 block system.
102 /// `check_neg_evals=true` asks the linsol to verify that the
103 /// observed inertia equals `num_neg_evals`; on mismatch the
104 /// status is `WrongInertia` and the solution is left untouched.
105 fn solve(
106 &mut self,
107 coeffs: &AugSysCoeffs<'_>,
108 rhs: &AugSysRhs<'_>,
109 sol: &mut AugSysSol<'_>,
110 check_neg_evals: bool,
111 num_neg_evals: Index,
112 ) -> ESymSolverStatus;
113
114 /// Back-substitution only, reusing the factorization from the most
115 /// recent successful `solve`. Caller must guarantee the augmented
116 /// matrix is byte-identical to that solve (same W, J_c, J_d, all
117 /// diagonals, all perturbations, same pivot tolerance). Used by
118 /// `PdFullSpaceSolver`'s iterative-refinement loop and same-matrix
119 /// fast path to avoid the per-iter MA57BD refactor that dominates
120 /// pounce-ma57 wall time on long-iter problems (e.g. cont5_2_4_l
121 /// drops from 97s → ~30s once refactor-per-refinement is gone).
122 ///
123 /// Default impl falls through to `solve` (correct but slow);
124 /// `StdAugSystemSolver` overrides to skip `refill_values` and pass
125 /// `new_matrix=false` to the linear solver.
126 fn resolve(
127 &mut self,
128 coeffs: &AugSysCoeffs<'_>,
129 rhs: &AugSysRhs<'_>,
130 sol: &mut AugSysSol<'_>,
131 ) -> ESymSolverStatus {
132 self.solve(coeffs, rhs, sol, false, 0)
133 }
134
135 /// Solve the same KKT system for `nrhs` right-hand sides. Default
136 /// impl loops [`solve`]; concrete backends override only when they
137 /// can amortize factorization across calls. Mirrors upstream's
138 /// `AugSystemSolver::MultiSolve` (`IpAugSystemSolver.hpp:113-150`).
139 ///
140 /// `rhs_list` and `sol_list` must have the same length; each pair
141 /// describes one independent solve. The same `coeffs` are used for
142 /// every column.
143 fn multi_solve(
144 &mut self,
145 coeffs: &AugSysCoeffs<'_>,
146 rhs_list: &[&AugSysRhs<'_>],
147 sol_list: &mut [&mut AugSysSol<'_>],
148 check_neg_evals: bool,
149 num_neg_evals: Index,
150 ) -> ESymSolverStatus {
151 debug_assert_eq!(rhs_list.len(), sol_list.len());
152 for (rhs, sol) in rhs_list.iter().zip(sol_list.iter_mut()) {
153 let status = self.solve(coeffs, rhs, *sol, check_neg_evals, num_neg_evals);
154 if status != ESymSolverStatus::Success {
155 return status;
156 }
157 }
158 ESymSolverStatus::Success
159 }
160
161 /// Back-substitution only against the cached factor for
162 /// `nrhs` right-hand sides, packed in **column-major** layout in
163 /// `packed_rhs`. Each column has length `dim = n_x + n_s + n_y_c +
164 /// n_y_d` (the aug-system dim — z/v blocks are not part of this
165 /// path; callers expand them via `expand_bound_multipliers` after
166 /// the fact). Solutions overwrite `packed_rhs` in place.
167 ///
168 /// Returns `None` when the backend does not support this fast
169 /// path; the caller should then fall back to a per-RHS loop over
170 /// [`resolve`]. The contract on `coeffs` and `have_factor` matches
171 /// [`resolve`]'s.
172 ///
173 /// `StdAugSystemSolver` overrides this to forward to
174 /// `pounce_linsol::TSymLinearSolver::multi_solve` with `nrhs > 1`,
175 /// which lets the underlying backend (FERAL / MA57 / LAPACK)
176 /// amortize per-call setup and, where supported, block the
177 /// triangular solves. Used by `pounce-sensitivity` for the JaxProblem
178 /// `jacrev` backward, where every cotangent re-solves against the
179 /// same converged factor (pounce#77 follow-up).
180 fn try_resolve_many_flat(
181 &mut self,
182 _coeffs: &AugSysCoeffs<'_>,
183 _packed_rhs: &mut [Number],
184 _nrhs: usize,
185 ) -> Option<ESymSolverStatus> {
186 None
187 }
188}