// Source: pounce_algorithm/kkt/aug_system_solver.rs

//! Augmented-system solver trait — port of `IpAugSystemSolver.hpp`.
//!
//! Solves the symmetric saddle-point system
//!
//! ```text
//!   [ W·factor + Σ_x + δ_x I       0          J_c^T   J_d^T ] [ dx ]   [ rx ]
//!   [          0           Σ_s + δ_s I        0       -I    ] [ ds ] = [ rs ]
//!   [         J_c                  0       -Σ_c-δ_c    0    ] [ dyc]   [ rc ]
//!   [         J_d                 -I            0   -Σ_d-δ_d] [ dyd]   [ rd ]
//! ```
//!
//! See `KKT_SYSTEM.md` §3 for the sign convention. `Σ_x = D_x`,
//! `Σ_s = D_s`, `Σ_c = D_c`, `Σ_d = D_d` are the diagonal weights pulled
//! from `IpoptCalculatedQuantities`. Any of the `D_*` may be `None`,
//! which is interpreted as zero. The `delta_*` values are the
//! perturbations driven by the `PerturbationHandler`.

use pounce_common::timing::TimingStatistics;
use pounce_common::types::{Index, Number};
use pounce_linalg::{Matrix, SymMatrix, Vector};
use pounce_linsol::ESymSolverStatus;
use std::rc::Rc;

24/// Bundle of the matrices/vectors that define one augmented-system
25/// instance. Lives only for the duration of the call. Mirrors the
26/// long argument list of upstream `AugSystemSolver::Solve`.
27pub struct AugSysCoeffs<'a> {
28    /// Hessian-of-Lagrangian block. `None` means W = 0 (used by
29    /// `LeastSquareMults` and the resto-NLP equality multiplier
30    /// estimate).
31    pub w: Option<&'a dyn SymMatrix>,
32    /// Multiplier on `W` (typically 1.0; restoration uses ζ).
33    pub w_factor: Number,
34    /// `D_x`, the (1,1) primal weight diagonal. `None` means zero.
35    pub d_x: Option<&'a dyn Vector>,
36    pub delta_x: Number,
37    /// `D_s`, the (2,2) slack weight diagonal. `None` means zero.
38    pub d_s: Option<&'a dyn Vector>,
39    pub delta_s: Number,
40    /// Equality-constraint Jacobian, `m_c × n_x`.
41    pub j_c: &'a dyn Matrix,
42    /// `D_c`, the (3,3) diagonal weight. `None` means zero. Goes in
43    /// with a *negative* sign, matching upstream.
44    pub d_c: Option<&'a dyn Vector>,
45    pub delta_c: Number,
46    /// Inequality-constraint Jacobian, `m_d × n_x`.
47    pub j_d: &'a dyn Matrix,
48    /// `D_d`, the (4,4) diagonal weight. `None` means zero. Goes in
49    /// with a *negative* sign, matching upstream.
50    pub d_d: Option<&'a dyn Vector>,
51    pub delta_d: Number,
52}
53
54/// Right-hand sides for one solve. All four slices are required;
55/// upstream always provides all four (even if some are zero).
56pub struct AugSysRhs<'a> {
57    pub rhs_x: &'a dyn Vector,
58    pub rhs_s: &'a dyn Vector,
59    pub rhs_c: &'a dyn Vector,
60    pub rhs_d: &'a dyn Vector,
61}
62
63/// Solution slots, written in place. Must already be sized to match
64/// the corresponding RHS dim.
65pub struct AugSysSol<'a> {
66    pub sol_x: &'a mut dyn Vector,
67    pub sol_s: &'a mut dyn Vector,
68    pub sol_c: &'a mut dyn Vector,
69    pub sol_d: &'a mut dyn Vector,
70}
71
72/// Trait surface mirroring `Ipopt::AugSystemSolver`.
73pub trait AugSystemSolver {
74    /// Whether the underlying linear solver reports inertia.
75    fn provides_inertia(&self) -> bool;
76
77    /// Number of negative eigenvalues observed in the most recent
78    /// factorization. Caller checks `provides_inertia()` first.
79    fn number_of_neg_evals(&self) -> Index;
80
81    /// Ask the underlying solver for higher-quality pivoting.
82    fn increase_quality(&mut self) -> bool;
83
84    /// Status of the most recent `solve` call.
85    fn last_solve_status(&self) -> ESymSolverStatus;
86
87    /// Install the shared per-solve `TimingStatistics` so the
88    /// linear-system factor/back-solve calls are attributed to
89    /// `linear_system_factorization` / `linear_system_back_solve`.
90    /// Default impl is a no-op (timing disabled); the standard
91    /// solver overrides to record both fields, and composite solvers
92    /// (LowRank) forward to their inner solver.
93    fn set_timing_stats(&mut self, _timing: Rc<TimingStatistics>) {}
94
95    /// Install the shared per-solve diagnostics state so KKT-dump
96    /// sites can consult per-iter gating. Default impl is a no-op
97    /// (diagnostics disabled); the standard solver overrides to wire
98    /// in the dump path.
99    fn set_diagnostics(&mut self, _diag: Rc<pounce_common::diagnostics::DiagnosticsState>) {}
100
101    /// One factor + back-substitution for the full 4×4 block system.
102    /// `check_neg_evals=true` asks the linsol to verify that the
103    /// observed inertia equals `num_neg_evals`; on mismatch the
104    /// status is `WrongInertia` and the solution is left untouched.
105    fn solve(
106        &mut self,
107        coeffs: &AugSysCoeffs<'_>,
108        rhs: &AugSysRhs<'_>,
109        sol: &mut AugSysSol<'_>,
110        check_neg_evals: bool,
111        num_neg_evals: Index,
112    ) -> ESymSolverStatus;
113
114    /// Back-substitution only, reusing the factorization from the most
115    /// recent successful `solve`. Caller must guarantee the augmented
116    /// matrix is byte-identical to that solve (same W, J_c, J_d, all
117    /// diagonals, all perturbations, same pivot tolerance). Used by
118    /// `PdFullSpaceSolver`'s iterative-refinement loop and same-matrix
119    /// fast path to avoid the per-iter MA57BD refactor that dominates
120    /// pounce-ma57 wall time on long-iter problems (e.g. cont5_2_4_l
121    /// drops from 97s → ~30s once refactor-per-refinement is gone).
122    ///
123    /// Default impl falls through to `solve` (correct but slow);
124    /// `StdAugSystemSolver` overrides to skip `refill_values` and pass
125    /// `new_matrix=false` to the linear solver.
126    fn resolve(
127        &mut self,
128        coeffs: &AugSysCoeffs<'_>,
129        rhs: &AugSysRhs<'_>,
130        sol: &mut AugSysSol<'_>,
131    ) -> ESymSolverStatus {
132        self.solve(coeffs, rhs, sol, false, 0)
133    }
134
135    /// Solve the same KKT system for `nrhs` right-hand sides. Default
136    /// impl loops [`solve`]; concrete backends override only when they
137    /// can amortize factorization across calls. Mirrors upstream's
138    /// `AugSystemSolver::MultiSolve` (`IpAugSystemSolver.hpp:113-150`).
139    ///
140    /// `rhs_list` and `sol_list` must have the same length; each pair
141    /// describes one independent solve. The same `coeffs` are used for
142    /// every column.
143    fn multi_solve(
144        &mut self,
145        coeffs: &AugSysCoeffs<'_>,
146        rhs_list: &[&AugSysRhs<'_>],
147        sol_list: &mut [&mut AugSysSol<'_>],
148        check_neg_evals: bool,
149        num_neg_evals: Index,
150    ) -> ESymSolverStatus {
151        debug_assert_eq!(rhs_list.len(), sol_list.len());
152        for (rhs, sol) in rhs_list.iter().zip(sol_list.iter_mut()) {
153            let status = self.solve(coeffs, rhs, *sol, check_neg_evals, num_neg_evals);
154            if status != ESymSolverStatus::Success {
155                return status;
156            }
157        }
158        ESymSolverStatus::Success
159    }
160
161    /// Back-substitution only against the cached factor for
162    /// `nrhs` right-hand sides, packed in **column-major** layout in
163    /// `packed_rhs`. Each column has length `dim = n_x + n_s + n_y_c +
164    /// n_y_d` (the aug-system dim — z/v blocks are not part of this
165    /// path; callers expand them via `expand_bound_multipliers` after
166    /// the fact). Solutions overwrite `packed_rhs` in place.
167    ///
168    /// Returns `None` when the backend does not support this fast
169    /// path; the caller should then fall back to a per-RHS loop over
170    /// [`resolve`]. The contract on `coeffs` and `have_factor` matches
171    /// [`resolve`]'s.
172    ///
173    /// `StdAugSystemSolver` overrides this to forward to
174    /// `pounce_linsol::TSymLinearSolver::multi_solve` with `nrhs > 1`,
175    /// which lets the underlying backend (FERAL / MA57 / LAPACK)
176    /// amortize per-call setup and, where supported, block the
177    /// triangular solves. Used by `pounce-sensitivity` for the JaxProblem
178    /// `jacrev` backward, where every cotangent re-solves against the
179    /// same converged factor (pounce#77 follow-up).
180    fn try_resolve_many_flat(
181        &mut self,
182        _coeffs: &AugSysCoeffs<'_>,
183        _packed_rhs: &mut [Number],
184        _nrhs: usize,
185    ) -> Option<ESymSolverStatus> {
186        None
187    }
188}