Skip to main content

pounce_sensitivity/
algorithm_backsolver.rs

1//! `PdSensBacksolver` — `SensBacksolver` adapter over the converged
2//! `PdFullSpaceSolver` from `pounce-algorithm`.
3//!
4//! This is the Phase B.2 piece tracked in
5//! [pounce#16](https://github.com/jkitchin/pounce/issues/16): it lets
6//! `pounce-sensitivity` drive backsolves against the real converged
7//! KKT factor, replacing the synthetic [`crate::DenseLuBacksolver`]
8//! used by Phase B.1 unit tests.
9//!
10//! # Use
11//!
12//! 1. Register an `on_converged` callback on `IpoptApplication` via
13//!    [`pounce_algorithm::application::IpoptApplication::set_on_converged`].
14//! 2. Inside the callback, build a `PdSensBacksolver` from the four
15//!    handles (`data`, `cq`, `nlp`, and the shared `Rc<RefCell<PdFullSpaceSolver>>`).
16//! 3. Hand it to [`crate::SensApplication`] / a `SensStepCalc` /
17//!    [`crate::compute_reduced_hessian`] like any other
18//!    [`SensBacksolver`].
19//!
20//! Upstream `SensSimpleBacksolver`
21//! ([`ref/Ipopt/contrib/sIPOPT/src/SensSimpleBacksolver.cpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSimpleBacksolver.cpp))
22//! is the analogous wrapper around `IpoptCalculatedQuantities` +
23//! `PDSystemSolver` upstream.
24//!
25//! # Flat-slice ↔ `IteratesVector` mapping
26//!
27//! The full primal-dual state of pounce's IPM is the eight-block
28//! compound `(x, s, λ_c, λ_d, z_l, z_u, v_l, v_u)` (see
29//! [`pounce_algorithm::iterates_vector::IteratesVector`]). This
30//! adapter packs / unpacks the flat slices that
31//! [`crate::SensBacksolver`] takes as the concatenation
32//! `x || s || λ_c || λ_d || z_l || z_u || v_l || v_u`, mirroring
33//! upstream's `CompoundVector` layout (`IpCompoundVector.hpp`).
34//!
35//! # Reference
36//!
37//! Pirnay, H.; López-Negrete, R.; Biegler, L. T. (2012). *Optimal
38//! sensitivity based on IPOPT*. Mathematical Programming Computation,
39//! **4**(4), 307–331. DOI:
40//! [10.1007/s12532-012-0043-2](https://doi.org/10.1007/s12532-012-0043-2).
41//! Verified via Crossref on 2026-05-13.
42
43use std::cell::RefCell;
44use std::rc::Rc;
45
46use pounce_algorithm::ipopt_cq::IpoptCqHandle;
47use pounce_algorithm::ipopt_data::IpoptDataHandle;
48use pounce_algorithm::iterates_vector::{IteratesVector, IteratesVectorMut};
49use pounce_algorithm::kkt::pd_full_space_solver::PdFullSpaceSolver;
50use pounce_common::types::Number;
51use pounce_linalg::dense_vector::DenseVector;
52use pounce_nlp::ipopt_nlp::IpoptNlp;
53
54use crate::backsolver::SensBacksolver;
55
56/// Adapter from `PdFullSpaceSolver` to [`SensBacksolver`]. Holds
57/// owning clones of the four pieces of the algorithm's converged
58/// state, plus the 8-block iterate template used to allocate fresh
59/// RHS / LHS vectors.
60///
61/// The PD solver lives behind an `Rc<RefCell<…>>` because
62/// [`SensBacksolver::solve`] is `&self` but the upstream signature
63/// for `PdFullSpaceSolver::solve` is `&mut self` (it caches the
64/// last-solve dependency tags and the augsys-improved flag). The
65/// `RefCell` is single-thread-only, single-borrow, exactly matching
66/// the call pattern from `pounce-sensitivity`'s pipeline.
67///
68/// Owning (rather than borrowing) the four handles is what lets a
69/// `PdSensBacksolver` outlive the `on_converged` callback frame —
70/// required by the public `Solver` session API in `pounce-algorithm`,
71/// which retains the backsolver for repeated `parametric_step` /
72/// `kkt_solve` / `compute_reduced_hessian` calls after the IPM has
73/// returned. The data, cq, and nlp handles are already
74/// `Rc<RefCell<…>>` cheap-clone handles upstream, so this carries no
75/// allocation overhead.
76#[derive(Clone)]
77pub struct PdSensBacksolver {
78    /// Shared, interior-mutable handle to the converged PD solver.
79    /// Cloned from `PdSearchDirCalc::pd_solver_rc()` at construction.
80    pd: Rc<RefCell<PdFullSpaceSolver>>,
81    data: IpoptDataHandle,
82    cq: IpoptCqHandle,
83    nlp: Rc<RefCell<dyn IpoptNlp>>,
84    /// Block dimensions in `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order.
85    dims: [usize; 8],
86    /// 8-block prototype used to mint fresh vectors with the same
87    /// `VectorSpace`s as the converged iterate; cloned from
88    /// `data.borrow().curr`.
89    template: IteratesVector,
90}
91
92impl PdSensBacksolver {
93    /// Construct from the four handles handed in by the `on_converged`
94    /// callback. Returns `Err(())` if `data` has no `curr` (i.e. the
95    /// algorithm never reached an iterate — should not happen on
96    /// `SolveSucceeded`).
97    pub fn new(
98        data: &IpoptDataHandle,
99        cq: &IpoptCqHandle,
100        nlp: &Rc<RefCell<dyn IpoptNlp>>,
101        pd: Rc<RefCell<PdFullSpaceSolver>>,
102    ) -> Result<Self, ()> {
103        let curr = data.borrow().curr.clone().ok_or(())?;
104        let dims = [
105            curr.x.dim() as usize,
106            curr.s.dim() as usize,
107            curr.y_c.dim() as usize,
108            curr.y_d.dim() as usize,
109            curr.z_l.dim() as usize,
110            curr.z_u.dim() as usize,
111            curr.v_l.dim() as usize,
112            curr.v_u.dim() as usize,
113        ];
114        Ok(Self {
115            pd,
116            data: Rc::clone(data),
117            cq: Rc::clone(cq),
118            nlp: Rc::clone(nlp),
119            dims,
120            template: curr,
121        })
122    }
123
124    /// Block dimensions of the compound KKT vector at convergence, in
125    /// `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order. Sum equals
126    /// [`SensBacksolver::dim`]. Useful when a caller needs to compute
127    /// the flat offset of a non-x block (e.g. `n_x + n_s` for the
128    /// start of the equality-multiplier `y_c` block).
129    pub fn block_dims(&self) -> [usize; 8] {
130        self.dims
131    }
132
133    /// Cumulative block offsets: `offset(i)` is the start index of
134    /// block `i` in the flat slice.
135    fn offsets(&self) -> [usize; 9] {
136        let mut o = [0usize; 9];
137        for i in 0..8 {
138            o[i + 1] = o[i] + self.dims[i];
139        }
140        o
141    }
142
143    /// Pack a flat slice into a freshly-allocated `IteratesVectorMut`
144    /// shaped like the converged iterate.
145    fn pack(&self, flat: &[Number]) -> Result<IteratesVectorMut, ()> {
146        let mut out = self.template.make_new_zeroed();
147        let off = self.offsets();
148        let blocks: [&mut Box<dyn pounce_linalg::vector::Vector>; 8] = [
149            &mut out.x,
150            &mut out.s,
151            &mut out.y_c,
152            &mut out.y_d,
153            &mut out.z_l,
154            &mut out.z_u,
155            &mut out.v_l,
156            &mut out.v_u,
157        ];
158        for (i, blk) in blocks.into_iter().enumerate() {
159            let slice = &flat[off[i]..off[i + 1]];
160            let dv = blk.as_any_mut().downcast_mut::<DenseVector>().ok_or(())?;
161            dv.set_values(slice);
162        }
163        Ok(out)
164    }
165
166    /// Read an `IteratesVectorMut` into a flat slice. Uses
167    /// [`DenseVector::expanded_values`] rather than `values()` so
168    /// blocks that the IPM left in homogeneous-scalar form (typical
169    /// for empty z_l/z_u/v_l/v_u when the TNLP has no bounds) are
170    /// materialized rather than panicking.
171    fn unpack(&self, iv: &IteratesVectorMut, out: &mut [Number]) -> Result<(), ()> {
172        let off = self.offsets();
173        let blocks: [&Box<dyn pounce_linalg::vector::Vector>; 8] = [
174            &iv.x, &iv.s, &iv.y_c, &iv.y_d, &iv.z_l, &iv.z_u, &iv.v_l, &iv.v_u,
175        ];
176        for (i, blk) in blocks.into_iter().enumerate() {
177            let dst = &mut out[off[i]..off[i + 1]];
178            if dst.is_empty() {
179                continue;
180            }
181            let dv = (**blk).as_any().downcast_ref::<DenseVector>().ok_or(())?;
182            let ev = dv.expanded_values();
183            dst.copy_from_slice(&ev);
184        }
185        Ok(())
186    }
187}
188
189impl PdSensBacksolver {
190    /// Batched-RHS back-solve over the held factor. `rhs_flat` and
191    /// `lhs_flat` are row-major `(n_rhs, dim)` buffers. Equivalent to
192    /// looping [`SensBacksolver::solve`] over each row but reuses one
193    /// frozen `IteratesVector` for the RHS and one `IteratesVectorMut`
194    /// for the result across all `n_rhs` calls into
195    /// [`PdFullSpaceSolver::solve`]. The pack step writes into the
196    /// existing `DenseVector` storage via `Rc::get_mut` +
197    /// `set_values`, and the unpack step reads it back via `values()`
198    /// /`scalar()` — skipping the per-call 8-block `make_new_zeroed`
199    /// (Box alloc) in `pack` and the per-block `expanded_values()` Vec
200    /// alloc in `unpack` that otherwise dominate the held-factor
201    /// back-solve cost under `jax.jacrev` over a JaxProblem solve
202    /// (pounce#77 follow-up).
203    ///
204    /// The matrix and perturbation state inside `PdFullSpaceSolver`
205    /// are unchanged across calls, so each iteration hits the cached
206    /// fast path in `solve_once` (`uptodate && !pretend_singular`).
207    pub fn solve_many(&self, rhs_flat: &[Number], lhs_flat: &mut [Number], n_rhs: usize) -> bool {
208        let total = self.dim();
209        if rhs_flat.len() != n_rhs * total || lhs_flat.len() != n_rhs * total {
210            return false;
211        }
212        if n_rhs == 0 {
213            return true;
214        }
215        let off = self.offsets();
216
217        // Tier 1: fully-inline flat-slice path. `PdFullSpaceSolver::
218        // solve_many_cached_flat` downcasts the slack / z / v vectors to
219        // `DenseVector` and the bound-expansion matrices to
220        // `ExpansionMatrix` once at the top, then runs Phase 1 / Phase 3
221        // as raw scatter-add / divide loops on flat slices with no dyn
222        // dispatch in the per-RHS inner loops. Returns `None` if a
223        // downcast fails (homogeneous-on-non-empty block, unusual matrix
224        // type) — we fall to Tier 2.
225        {
226            let mut pd_ref = self.pd.borrow_mut();
227            let fast_flat = pd_ref.solve_many_cached_flat(
228                &self.data, &self.cq, &self.nlp, n_rhs, rhs_flat, lhs_flat, self.dims,
229            );
230            match fast_flat {
231                Some(true) => return true,
232                Some(false) => return false,
233                None => { /* fall through to Tier 2 */ }
234            }
235        }
236
237        // Tier 2: closure-based cached-factor path. Same single
238        // back-substitution through the linsol, but Phase 1 / Phase 3
239        // go through `dyn Vector` / `dyn Matrix` ops on a per-RHS
240        // `IteratesVectorMut`. Slower than Tier 1 but correct for
241        // homogeneous DenseVectors and non-`ExpansionMatrix` bound
242        // expansions.
243        {
244            let mut pd_ref = self.pd.borrow_mut();
245            let fast = pd_ref.solve_many_cached(
246                &self.data,
247                &self.cq,
248                &self.nlp,
249                n_rhs,
250                |k, iv| {
251                    let row = &rhs_flat[k * total..(k + 1) * total];
252                    let _ = write_rhs_box(&mut iv.x, &row[off[0]..off[1]])
253                        && write_rhs_box(&mut iv.s, &row[off[1]..off[2]])
254                        && write_rhs_box(&mut iv.y_c, &row[off[2]..off[3]])
255                        && write_rhs_box(&mut iv.y_d, &row[off[3]..off[4]])
256                        && write_rhs_box(&mut iv.z_l, &row[off[4]..off[5]])
257                        && write_rhs_box(&mut iv.z_u, &row[off[5]..off[6]])
258                        && write_rhs_box(&mut iv.v_l, &row[off[6]..off[7]])
259                        && write_rhs_box(&mut iv.v_u, &row[off[7]..off[8]]);
260                },
261                |k, iv| {
262                    let row = &mut lhs_flat[k * total..(k + 1) * total];
263                    let _ = read_res_block(&*iv.x, &mut row[off[0]..off[1]])
264                        && read_res_block(&*iv.s, &mut row[off[1]..off[2]])
265                        && read_res_block(&*iv.y_c, &mut row[off[2]..off[3]])
266                        && read_res_block(&*iv.y_d, &mut row[off[3]..off[4]])
267                        && read_res_block(&*iv.z_l, &mut row[off[4]..off[5]])
268                        && read_res_block(&*iv.z_u, &mut row[off[5]..off[6]])
269                        && read_res_block(&*iv.v_l, &mut row[off[6]..off[7]])
270                        && read_res_block(&*iv.v_u, &mut row[off[7]..off[8]]);
271                },
272            );
273            match fast {
274                Some(true) => return true,
275                Some(false) => return false,
276                None => { /* fall through to per-RHS loop */ }
277            }
278        }
279
280        // Per-RHS fallback: reuse one frozen rhs and one mut sol across
281        // all n_rhs `solve` calls.
282        let rhs_mut0 = self.template.make_new_zeroed();
283        let mut rhs_iv = rhs_mut0.freeze();
284        let mut res_iv = self.template.make_new_zeroed();
285
286        let mut pd_ref = self.pd.borrow_mut();
287        for k in 0..n_rhs {
288            let rhs_row = &rhs_flat[k * total..(k + 1) * total];
289            let lhs_row = &mut lhs_flat[k * total..(k + 1) * total];
290
291            if !write_rhs_block(&mut rhs_iv.x, &rhs_row[off[0]..off[1]])
292                || !write_rhs_block(&mut rhs_iv.s, &rhs_row[off[1]..off[2]])
293                || !write_rhs_block(&mut rhs_iv.y_c, &rhs_row[off[2]..off[3]])
294                || !write_rhs_block(&mut rhs_iv.y_d, &rhs_row[off[3]..off[4]])
295                || !write_rhs_block(&mut rhs_iv.z_l, &rhs_row[off[4]..off[5]])
296                || !write_rhs_block(&mut rhs_iv.z_u, &rhs_row[off[5]..off[6]])
297                || !write_rhs_block(&mut rhs_iv.v_l, &rhs_row[off[6]..off[7]])
298                || !write_rhs_block(&mut rhs_iv.v_u, &rhs_row[off[7]..off[8]])
299            {
300                return false;
301            }
302
303            let ok = pd_ref.solve(
304                &self.data,
305                &self.cq,
306                &self.nlp,
307                1.0,
308                0.0,
309                &rhs_iv,
310                &mut res_iv,
311                /* allow_inexact = */ true,
312                /* improve_solution = */ false,
313            );
314            if !ok {
315                return false;
316            }
317
318            if !read_res_block(&*res_iv.x, &mut lhs_row[off[0]..off[1]])
319                || !read_res_block(&*res_iv.s, &mut lhs_row[off[1]..off[2]])
320                || !read_res_block(&*res_iv.y_c, &mut lhs_row[off[2]..off[3]])
321                || !read_res_block(&*res_iv.y_d, &mut lhs_row[off[3]..off[4]])
322                || !read_res_block(&*res_iv.z_l, &mut lhs_row[off[4]..off[5]])
323                || !read_res_block(&*res_iv.z_u, &mut lhs_row[off[5]..off[6]])
324                || !read_res_block(&*res_iv.v_l, &mut lhs_row[off[6]..off[7]])
325                || !read_res_block(&*res_iv.v_u, &mut lhs_row[off[7]..off[8]])
326            {
327                return false;
328            }
329        }
330        true
331    }
332}
333
334/// Write `slice` into the `DenseVector` behind `b` in place. Used by
335/// the fast path's `write_rhs` closure, where the new
336/// `PdFullSpaceSolver::solve_many_cached` API hands back an
337/// `IteratesVectorMut` (Box-backed blocks).
338fn write_rhs_box(b: &mut Box<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
339    if slice.is_empty() {
340        return true;
341    }
342    let Some(dv) = b.as_any_mut().downcast_mut::<DenseVector>() else {
343        return false;
344    };
345    dv.set_values(slice);
346    true
347}
348
349/// Write `slice` into the `DenseVector` behind `rc` in place. Returns
350/// `false` if the Rc is unexpectedly shared (would indicate a bug in
351/// `PdFullSpaceSolver::solve`'s borrow discipline — it should never
352/// `Rc::clone` from the rhs vector) or if the block is not a
353/// `DenseVector`.
354fn write_rhs_block(rc: &mut Rc<dyn pounce_linalg::vector::Vector>, slice: &[Number]) -> bool {
355    if slice.is_empty() {
356        return true;
357    }
358    let Some(v) = Rc::get_mut(rc) else {
359        return false;
360    };
361    let Some(dv) = v.as_any_mut().downcast_mut::<DenseVector>() else {
362        return false;
363    };
364    dv.set_values(slice);
365    true
366}
367
368/// Read the `DenseVector` behind `blk` into `dst`. Handles the
369/// homogeneous case (empty z/v blocks for a TNLP with no bounds) by
370/// broadcasting the scalar rather than calling `expanded_values()`,
371/// which would allocate a fresh `Vec<Number>` every call.
372fn read_res_block(blk: &dyn pounce_linalg::vector::Vector, dst: &mut [Number]) -> bool {
373    if dst.is_empty() {
374        return true;
375    }
376    let Some(dv) = blk.as_any().downcast_ref::<DenseVector>() else {
377        return false;
378    };
379    if dv.is_homogeneous() {
380        let s = dv.scalar();
381        for x in dst.iter_mut() {
382            *x = s;
383        }
384    } else {
385        dst.copy_from_slice(dv.values());
386    }
387    true
388}
389
390impl SensBacksolver for PdSensBacksolver {
391    fn dim(&self) -> usize {
392        self.dims.iter().sum()
393    }
394
395    fn solve(&self, rhs: &[Number], lhs: &mut [Number]) -> bool {
396        let total = self.dim();
397        if rhs.len() != total || lhs.len() != total {
398            return false;
399        }
400        // Pack rhs into block form.
401        let rhs_mut = match self.pack(rhs) {
402            Ok(v) => v,
403            Err(()) => return false,
404        };
405        let rhs_iv = rhs_mut.freeze();
406        // Fresh result slot, zeroed.
407        let mut res_iv = self.template.make_new_zeroed();
408
409        // K · lhs = rhs   ⇒   solve(α=1, β=0, rhs, res) writes
410        // res = K⁻¹ · rhs.
411        //
412        // `allow_inexact=true` mirrors upstream sIPOPT's
413        // `SensSimpleBacksolver`: skip `PdFullSpaceSolver`'s iterative-
414        // refinement loop and accept the first back-solve against the
415        // held factor. The IPM-level refinement (`min_refinement_steps
416        // = 1`, residual_ratio_max = 1e-10`) is there to clean up
417        // numerical noise during forward IPM steps; for the held-factor
418        // back-solve used by sens / JaxProblem bwd, it ~doubles the
419        // per-call cost and produces gains that are below `tol`. Under
420        // `jax.jacrev` over a JaxProblem solve this dominates the wall
421        // time at moderate `n+m` (pounce#77 follow-up).
422        let ok = {
423            let mut pd_ref = self.pd.borrow_mut();
424            pd_ref.solve(
425                &self.data,
426                &self.cq,
427                &self.nlp,
428                1.0,
429                0.0,
430                &rhs_iv,
431                &mut res_iv,
432                /* allow_inexact = */ true,
433                /* improve_solution = */ false,
434            )
435        };
436        if !ok {
437            return false;
438        }
439        self.unpack(&res_iv, lhs).is_ok()
440    }
441}