pounce_sensitivity/algorithm_backsolver.rs
1//! `PdSensBacksolver` — `SensBacksolver` adapter over the converged
2//! `PdFullSpaceSolver` from `pounce-algorithm`.
3//!
4//! This is the Phase B.2 piece tracked in
5//! [pounce#16](https://github.com/jkitchin/pounce/issues/16): it lets
6//! `pounce-sensitivity` drive backsolves against the real converged
7//! KKT factor, replacing the synthetic [`crate::DenseLuBacksolver`]
8//! used by Phase B.1 unit tests.
9//!
10//! # Use
11//!
12//! 1. Register an `on_converged` callback on `IpoptApplication` via
13//! [`pounce_algorithm::application::IpoptApplication::set_on_converged`].
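//! 2. Inside the callback, build a `PdSensBacksolver` with
//! [`PdSensBacksolver::new`] from the `data`, `cq`, and `nlp` handles passed
//! in, plus a shared `Rc<RefCell<PdFullSpaceSolver>>` handle to the converged
//! PD solver (e.g. cloned from `PdSearchDirCalc::pd_solver_rc()`).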
16//! 3. Hand it to [`crate::SensApplication`] / a `SensStepCalc` /
17//! [`crate::compute_reduced_hessian`] like any other
18//! [`SensBacksolver`].
19//!
20//! Upstream `SensSimpleBacksolver`
21//! ([`ref/Ipopt/contrib/sIPOPT/src/SensSimpleBacksolver.cpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSimpleBacksolver.cpp))
22//! is the analogous wrapper around `IpoptCalculatedQuantities` +
23//! `PDSystemSolver` upstream.
24//!
25//! # Flat-slice ↔ `IteratesVector` mapping
26//!
27//! The full primal-dual state of pounce's IPM is the eight-block
28//! compound `(x, s, λ_c, λ_d, z_l, z_u, v_l, v_u)` (see
29//! [`pounce_algorithm::iterates_vector::IteratesVector`]). This
30//! adapter packs / unpacks the flat slices that
31//! [`crate::SensBacksolver`] takes as the concatenation
32//! `x || s || λ_c || λ_d || z_l || z_u || v_l || v_u`, mirroring
33//! upstream's `CompoundVector` layout (`IpCompoundVector.hpp`).
34//!
35//! # Reference
36//!
37//! Pirnay, H.; López-Negrete, R.; Biegler, L. T. (2012). *Optimal
38//! sensitivity based on IPOPT*. Mathematical Programming Computation,
39//! **4**(4), 307–331. DOI:
40//! [10.1007/s12532-012-0043-2](https://doi.org/10.1007/s12532-012-0043-2).
41//! Verified via Crossref on 2026-05-13.
42
43use std::cell::RefCell;
44use std::rc::Rc;
45
46use pounce_algorithm::ipopt_cq::IpoptCqHandle;
47use pounce_algorithm::ipopt_data::IpoptDataHandle;
48use pounce_algorithm::iterates_vector::{IteratesVector, IteratesVectorMut};
49use pounce_algorithm::kkt::pd_full_space_solver::PdFullSpaceSolver;
50use pounce_common::types::Number;
51use pounce_linalg::dense_vector::DenseVector;
52use pounce_nlp::ipopt_nlp::IpoptNlp;
53
54use crate::backsolver::SensBacksolver;
55
56/// Adapter from `PdFullSpaceSolver` to [`SensBacksolver`]. Holds
57/// owning clones of the four pieces of the algorithm's converged
58/// state, plus the 8-block iterate template used to allocate fresh
59/// RHS / LHS vectors.
60///
61/// The PD solver lives behind an `Rc<RefCell<…>>` because
62/// [`SensBacksolver::solve`] is `&self` but the upstream signature
63/// for `PdFullSpaceSolver::solve` is `&mut self` (it caches the
64/// last-solve dependency tags and the augsys-improved flag). The
65/// `RefCell` is single-thread-only, single-borrow, exactly matching
66/// the call pattern from `pounce-sensitivity`'s pipeline.
67///
68/// Owning (rather than borrowing) the four handles is what lets a
69/// `PdSensBacksolver` outlive the `on_converged` callback frame —
70/// required by the public `Solver` session API in `pounce-algorithm`,
71/// which retains the backsolver for repeated `parametric_step` /
72/// `kkt_solve` / `compute_reduced_hessian` calls after the IPM has
73/// returned. The data, cq, and nlp handles are already
74/// `Rc<RefCell<…>>` cheap-clone handles upstream, so this carries no
75/// allocation overhead.
76#[derive(Clone)]
77pub struct PdSensBacksolver {
78 /// Shared, interior-mutable handle to the converged PD solver.
79 /// Cloned from `PdSearchDirCalc::pd_solver_rc()` at construction.
80 pd: Rc<RefCell<PdFullSpaceSolver>>,
81 data: IpoptDataHandle,
82 cq: IpoptCqHandle,
83 nlp: Rc<RefCell<dyn IpoptNlp>>,
84 /// Block dimensions in `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order.
85 dims: [usize; 8],
86 /// 8-block prototype used to mint fresh vectors with the same
87 /// `VectorSpace`s as the converged iterate; cloned from
88 /// `data.borrow().curr`.
89 template: IteratesVector,
90}
91
92impl PdSensBacksolver {
93 /// Construct from the four handles handed in by the `on_converged`
94 /// callback. Returns `Err(())` if `data` has no `curr` (i.e. the
95 /// algorithm never reached an iterate — should not happen on
96 /// `SolveSucceeded`).
97 pub fn new(
98 data: &IpoptDataHandle,
99 cq: &IpoptCqHandle,
100 nlp: &Rc<RefCell<dyn IpoptNlp>>,
101 pd: Rc<RefCell<PdFullSpaceSolver>>,
102 ) -> Result<Self, ()> {
103 let curr = data.borrow().curr.clone().ok_or(())?;
104 let dims = [
105 curr.x.dim() as usize,
106 curr.s.dim() as usize,
107 curr.y_c.dim() as usize,
108 curr.y_d.dim() as usize,
109 curr.z_l.dim() as usize,
110 curr.z_u.dim() as usize,
111 curr.v_l.dim() as usize,
112 curr.v_u.dim() as usize,
113 ];
114 Ok(Self {
115 pd,
116 data: Rc::clone(data),
117 cq: Rc::clone(cq),
118 nlp: Rc::clone(nlp),
119 dims,
120 template: curr,
121 })
122 }
123
124 /// Block dimensions of the compound KKT vector at convergence, in
125 /// `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order. Sum equals
126 /// [`SensBacksolver::dim`]. Useful when a caller needs to compute
127 /// the flat offset of a non-x block (e.g. `n_x + n_s` for the
128 /// start of the equality-multiplier `y_c` block).
129 pub fn block_dims(&self) -> [usize; 8] {
130 self.dims
131 }
132
133 /// Cumulative block offsets: `offset(i)` is the start index of
134 /// block `i` in the flat slice.
135 fn offsets(&self) -> [usize; 9] {
136 let mut o = [0usize; 9];
137 for i in 0..8 {
138 o[i + 1] = o[i] + self.dims[i];
139 }
140 o
141 }
142
143 /// Pack a flat slice into a freshly-allocated `IteratesVectorMut`
144 /// shaped like the converged iterate.
145 fn pack(&self, flat: &[Number]) -> Result<IteratesVectorMut, ()> {
146 let mut out = self.template.make_new_zeroed();
147 let off = self.offsets();
148 let blocks: [&mut Box<dyn pounce_linalg::vector::Vector>; 8] = [
149 &mut out.x,
150 &mut out.s,
151 &mut out.y_c,
152 &mut out.y_d,
153 &mut out.z_l,
154 &mut out.z_u,
155 &mut out.v_l,
156 &mut out.v_u,
157 ];
158 for (i, blk) in blocks.into_iter().enumerate() {
159 let slice = &flat[off[i]..off[i + 1]];
160 let dv = blk.as_any_mut().downcast_mut::<DenseVector>().ok_or(())?;
161 dv.set_values(slice);
162 }
163 Ok(out)
164 }
165
166 /// Read an `IteratesVectorMut` into a flat slice. Uses
167 /// [`DenseVector::expanded_values`] rather than `values()` so
168 /// blocks that the IPM left in homogeneous-scalar form (typical
169 /// for empty z_l/z_u/v_l/v_u when the TNLP has no bounds) are
170 /// materialized rather than panicking.
171 fn unpack(&self, iv: &IteratesVectorMut, out: &mut [Number]) -> Result<(), ()> {
172 let off = self.offsets();
173 let blocks: [&Box<dyn pounce_linalg::vector::Vector>; 8] = [
174 &iv.x, &iv.s, &iv.y_c, &iv.y_d, &iv.z_l, &iv.z_u, &iv.v_l, &iv.v_u,
175 ];
176 for (i, blk) in blocks.into_iter().enumerate() {
177 let dst = &mut out[off[i]..off[i + 1]];
178 if dst.is_empty() {
179 continue;
180 }
181 let dv = (**blk).as_any().downcast_ref::<DenseVector>().ok_or(())?;
182 let ev = dv.expanded_values();
183 dst.copy_from_slice(&ev);
184 }
185 Ok(())
186 }
187}
188
189impl PdSensBacksolver {
190 /// Batched-RHS back-solve over the held factor. `rhs_flat` and
191 /// `lhs_flat` are row-major `(n_rhs, dim)` buffers. Equivalent to
192 /// looping [`SensBacksolver::solve`] over each row but reuses one
193 /// frozen `IteratesVector` for the RHS and one `IteratesVectorMut`
194 /// for the result across all `n_rhs` calls into
195 /// [`PdFullSpaceSolver::solve`]. The pack step writes into the
196 /// existing `DenseVector` storage via `Rc::get_mut` +
197 /// `set_values`, and the unpack step reads it back via `values()`
198 /// /`scalar()` — skipping the per-call 8-block `make_new_zeroed`
199 /// (Box alloc) in `pack` and the per-block `expanded_values()` Vec
200 /// alloc in `unpack` that otherwise dominate the held-factor
201 /// back-solve cost under `jax.jacrev` over a JaxProblem solve
202 /// (pounce#77 follow-up).
203 ///
204 /// The matrix and perturbation state inside `PdFullSpaceSolver`
205 /// are unchanged across calls, so each iteration hits the cached
206 /// fast path in `solve_once` (`uptodate && !pretend_singular`).
207 pub fn solve_many(&self, rhs_flat: &[Number], lhs_flat: &mut [Number], n_rhs: usize) -> bool {
208 let total = self.dim();
209 if rhs_flat.len() != n_rhs * total || lhs_flat.len() != n_rhs * total {
210 return false;
211 }
212 if n_rhs == 0 {
213 return true;
214 }
215 let off = self.offsets();
216
217 // Tier 1: fully-inline flat-slice path. `PdFullSpaceSolver::
218 // solve_many_cached_flat` downcasts the slack / z / v vectors to
219 // `DenseVector` and the bound-expansion matrices to
220 // `ExpansionMatrix` once at the top, then runs Phase 1 / Phase 3
221 // as raw scatter-add / divide loops on flat slices with no dyn
222 // dispatch in the per-RHS inner loops. Returns `None` if a
223 // downcast fails (homogeneous-on-non-empty block, unusual matrix
224 // type) — we fall to Tier 2.
225 {
226 let mut pd_ref = self.pd.borrow_mut();
227 let fast_flat = pd_ref.solve_many_cached_flat(
228 &self.data, &self.cq, &self.nlp, n_rhs, rhs_flat, lhs_flat, self.dims,
229 );
230 match fast_flat {
231 Some(true) => return true,
232 Some(false) => return false,
233 None => { /* fall through to Tier 2 */ }
234 }
235 }
236
237 // Tier 2: closure-based cached-factor path. Same single
238 // back-substitution through the linsol, but Phase 1 / Phase 3
239 // go through `dyn Vector` / `dyn Matrix` ops on a per-RHS
240 // `IteratesVectorMut`. Slower than Tier 1 but correct for
241 // homogeneous DenseVectors and non-`ExpansionMatrix` bound
242 // expansions.
243 {
244 let mut pd_ref = self.pd.borrow_mut();
245 let fast = pd_ref.solve_many_cached(
246 &self.data,
247 &self.cq,
248 &self.nlp,
249 n_rhs,
250 |k, iv| {
251 let row = &rhs_flat[k * total..(k + 1) * total];
252 let _ = write_rhs_box(&mut iv.x, &row[off[0]..off[1]])
253 && write_rhs_box(&mut iv.s, &row[off[1]..off[2]])
254 && write_rhs_box(&mut iv.y_c, &row[off[2]..off[3]])
255 && write_rhs_box(&mut iv.y_d, &row[off[3]..off[4]])
256 && write_rhs_box(&mut iv.z_l, &row[off[4]..off[5]])
257 && write_rhs_box(&mut iv.z_u, &row[off[5]..off[6]])
258 && write_rhs_box(&mut iv.v_l, &row[off[6]..off[7]])
259 && write_rhs_box(&mut iv.v_u, &row[off[7]..off[8]]);
260 },
261 |k, iv| {
262 let row = &mut lhs_flat[k * total..(k + 1) * total];
263 let _ = read_res_block(&*iv.x, &mut row[off[0]..off[1]])
264 && read_res_block(&*iv.s, &mut row[off[1]..off[2]])
265 && read_res_block(&*iv.y_c, &mut row[off[2]..off[3]])
266 && read_res_block(&*iv.y_d, &mut row[off[3]..off[4]])
267 && read_res_block(&*iv.z_l, &mut row[off[4]..off[5]])
268 && read_res_block(&*iv.z_u, &mut row[off[5]..off[6]])
269 && read_res_block(&*iv.v_l, &mut row[off[6]..off[7]])
270 && read_res_block(&*iv.v_u, &mut row[off[7]..off[8]]);
271 },
272 );
273 match fast {
274 Some(true) => return true,
275 Some(false) => return false,
276 None => { /* fall through to per-RHS loop */ }
277 }
278 }
279
280 // Per-RHS fallback: reuse one frozen rhs and one mut sol across
281 // all n_rhs `solve` calls.
282 let rhs_mut0 = self.template.make_new_zeroed();
283 let mut rhs_iv = rhs_mut0.freeze();
284 let mut res_iv = self.template.make_new_zeroed();
285
286 let mut pd_ref = self.pd.borrow_mut();
287 for k in 0..n_rhs {
288 let rhs_row = &rhs_flat[k * total..(k + 1) * total];
289 let lhs_row = &mut lhs_flat[k * total..(k + 1) * total];
290
291 if !write_rhs_block(&mut rhs_iv.x, &rhs_row[off[0]..off[1]])
292 || !write_rhs_block(&mut rhs_iv.s, &rhs_row[off[1]..off[2]])
293 || !write_rhs_block(&mut rhs_iv.y_c, &rhs_row[off[2]..off[3]])
294 || !write_rhs_block(&mut rhs_iv.y_d, &rhs_row[off[3]..off[4]])
295 || !write_rhs_block(&mut rhs_iv.z_l, &rhs_row[off[4]..off[5]])
296 || !write_rhs_block(&mut rhs_iv.z_u, &rhs_row[off[5]..off[6]])
297 || !write_rhs_block(&mut rhs_iv.v_l, &rhs_row[off[6]..off[7]])
298 || !write_rhs_block(&mut rhs_iv.v_u, &rhs_row[off[7]..off[8]])
299 {
300 return false;
301 }
302
303 let ok = pd_ref.solve(
304 &self.data,
305 &self.cq,
306 &self.nlp,
307 1.0,
308 0.0,
309 &rhs_iv,
310 &mut res_iv,
311 /* allow_inexact = */ true,
312 /* improve_solution = */ false,
313 );
314 if !ok {
315 return false;
316 }
317
318 if !read_res_block(&*res_iv.x, &mut lhs_row[off[0]..off[1]])
319 || !read_res_block(&*res_iv.s, &mut lhs_row[off[1]..off[2]])
320 || !read_res_block(&*res_iv.y_c, &mut lhs_row[off[2]..off[3]])
321 || !read_res_block(&*res_iv.y_d, &mut lhs_row[off[3]..off[4]])
322 || !read_res_block(&*res_iv.z_l, &mut lhs_row[off[4]..off[5]])
323 || !read_res_block(&*res_iv.z_u, &mut lhs_row[off[5]..off[6]])
324 || !read_res_block(&*res_iv.v_l, &mut lhs_row[off[6]..off[7]])
325 || !read_res_block(&*res_iv.v_u, &mut lhs_row[off[7]..off[8]])
326 {
327 return false;
328 }
329 }
330 true
331 }
332}
333
334/// Write `slice` into the `DenseVector` behind `b` in place. Used by
335/// the fast path's `write_rhs` closure, where the new
336/// `PdFullSpaceSolver::solve_many_cached` API hands back an
337/// `IteratesVectorMut` (Box-backed blocks).
338fn write_rhs_box(b: &mut Box<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
339 if slice.is_empty() {
340 return true;
341 }
342 let Some(dv) = b.as_any_mut().downcast_mut::<DenseVector>() else {
343 return false;
344 };
345 dv.set_values(slice);
346 true
347}
348
349/// Write `slice` into the `DenseVector` behind `rc` in place. Returns
350/// `false` if the Rc is unexpectedly shared (would indicate a bug in
351/// `PdFullSpaceSolver::solve`'s borrow discipline — it should never
352/// `Rc::clone` from the rhs vector) or if the block is not a
353/// `DenseVector`.
354fn write_rhs_block(rc: &mut Rc<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
355 if slice.is_empty() {
356 return true;
357 }
358 let Some(v) = Rc::get_mut(rc) else {
359 return false;
360 };
361 let Some(dv) = v.as_any_mut().downcast_mut::<DenseVector>() else {
362 return false;
363 };
364 dv.set_values(slice);
365 true
366}
367
368/// Read the `DenseVector` behind `blk` into `dst`. Handles the
369/// homogeneous case (empty z/v blocks for a TNLP with no bounds) by
370/// broadcasting the scalar rather than calling `expanded_values()`,
371/// which would allocate a fresh `Vec<Number>` every call.
372fn read_res_block(blk: &dyn pounce_linalg::vector::Vector, dst: &mut [Number]) -> bool {
373 if dst.is_empty() {
374 return true;
375 }
376 let Some(dv) = blk.as_any().downcast_ref::<DenseVector>() else {
377 return false;
378 };
379 if dv.is_homogeneous() {
380 let s = dv.scalar();
381 for x in dst.iter_mut() {
382 *x = s;
383 }
384 } else {
385 dst.copy_from_slice(dv.values());
386 }
387 true
388}
389
390impl SensBacksolver for PdSensBacksolver {
391 fn dim(&self) -> usize {
392 self.dims.iter().sum()
393 }
394
395 fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
396 let total = self.dim();
397 if rhs.len() != total || lhs.len() != total {
398 return false;
399 }
400 // Pack rhs into block form.
401 let rhs_mut = match self.pack(rhs) {
402 Ok(v) => v,
403 Err(()) => return false,
404 };
405 let rhs_iv = rhs_mut.freeze();
406 // Fresh result slot, zeroed.
407 let mut res_iv = self.template.make_new_zeroed();
408
409 // K · lhs = rhs ⇒ solve(α=1, β=0, rhs, res) writes
410 // res = K⁻¹ · rhs.
411 //
412 // `allow_inexact=true` mirrors upstream sIPOPT's
413 // `SensSimpleBacksolver`: skip `PdFullSpaceSolver`'s iterative-
414 // refinement loop and accept the first back-solve against the
415 // held factor. The IPM-level refinement (`min_refinement_steps
416 // = 1`, residual_ratio_max = 1e-10`) is there to clean up
417 // numerical noise during forward IPM steps; for the held-factor
418 // back-solve used by sens / JaxProblem bwd, it ~doubles the
419 // per-call cost and produces gains that are below `tol`. Under
420 // `jax.jacrev` over a JaxProblem solve this dominates the wall
421 // time at moderate `n+m` (pounce#77 follow-up).
422 let ok = {
423 let mut pd_ref = self.pd.borrow_mut();
424 pd_ref.solve(
425 &self.data,
426 &self.cq,
427 &self.nlp,
428 1.0,
429 0.0,
430 &rhs_iv,
431 &mut res_iv,
432 /* allow_inexact = */ true,
433 /* improve_solution = */ false,
434 )
435 };
436 if !ok {
437 return false;
438 }
439 self.unpack(&res_iv, lhs).is_ok()
440 }
441}