Skip to main content

pounce_algorithm/kkt/
pd_full_space_solver.rs

1//! Full-space PD system solver — port of
2//! `Algorithm/IpPDFullSpaceSolver.{hpp,cpp}`.
3//!
4//! Iterative refinement on the FULL 8-block primal-dual KKT system,
5//! driving the augmented-system solver repeatedly. See
6//! `KKT_SYSTEM.md` §5 for the refinement-quit criteria. The outer
7//! loop alternates between back-solves and quality escalation
8//! (`AugSystemSolver::increase_quality()` and `pretend_singular`).
9
10use crate::ipopt_cq::IpoptCqHandle;
11use crate::ipopt_data::IpoptDataHandle;
12use crate::ipopt_nlp::IpoptNlp;
13use crate::iterates_vector::{IteratesVector, IteratesVectorMut};
14use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
15use crate::kkt::pd_system_solver::PdSystemSolver;
16use crate::kkt::perturbation_handler::PdPerturbationHandler;
17use pounce_common::tagged::Tag;
18use pounce_common::types::{Index, Number};
19use pounce_linalg::dense_vector::DenseVector;
20use pounce_linalg::expansion_matrix::ExpansionMatrix;
21use pounce_linalg::{Matrix, SymMatrix, Vector};
22use pounce_linsol::ESymSolverStatus;
23use std::cell::RefCell;
24use std::rc::Rc;
25
/// State of the full-space primal-dual solver.
///
/// Owns the augmented-system solver that performs the actual
/// back-solves, the iterative-refinement options, and the bookkeeping
/// used to decide whether the matrix seen by the previous call can be
/// reused.
pub struct PdFullSpaceSolver {
    /// Augmented-system solver that every solve is delegated to.
    /// `solve` calls `increase_quality()` on it when refinement
    /// stalls; the batched paths call `try_resolve_many_flat`.
    aug_solver: Box<dyn AugSystemSolver>,
    /// Shared perturbation handler. The batched cached-factor paths
    /// read `current_perturbation()` from it to rebuild the
    /// augmented-system coefficients.
    perturb: Rc<RefCell<PdPerturbationHandler>>,
    /// Minimum number of refinement steps. The refinement loop in
    /// `solve` keeps iterating while fewer than this many steps have
    /// been taken, and giving up is only considered once the step
    /// count exceeds it. `new` sets this to 1.
    pub min_refinement_steps: Index,
    /// Refinement is abandoned once more than this many steps have
    /// been taken and the residual ratio is still above
    /// `residual_ratio_max`. `new` sets this to 10.
    pub max_refinement_steps: Index,
    /// Refinement continues while the residual ratio is above this
    /// value. `new` sets this to 1e-10.
    pub residual_ratio_max: Number,
    /// When refinement gives up and the system would be declared
    /// singular, it is only treated as singular if the residual ratio
    /// is not below this value. `new` sets this to 1e-5.
    pub residual_ratio_singular: Number,
    /// A refinement step counts as stalled when the new residual
    /// ratio exceeds this factor times the previous one. `new` sets
    /// this to 0.999999999.
    pub residual_improvement_factor: Number,
    /// Negative-curvature test tolerance (`neg_curv_test_tol_`). Zero
    /// disables the heuristic; matches upstream's `RegisterOptions`
    /// default. The non-zero branch is not exercised in v1.0.
    pub neg_curv_test_tol: Number,
    /// Mirrors `augsys_improved_`. Set from the return value of
    /// `AugSystemSolver::increase_quality()` during quality
    /// escalation; cleared by `solve()` whenever the dependency tags
    /// differ from `last_dep_tags`.
    augsys_improved: bool,
    /// Mirrors upstream's `dummy_cache_` hit/miss. `false` ⇒ the next
    /// `solve_once` is operating on a *new* augmented matrix and must
    /// run the `ConsiderNewSystem` + perturbation-escalation path;
    /// `true` ⇒ the matrix is identical to the previous successful
    /// `solve_once`, so we can reuse `CurrentPerturbation` and just do
    /// a single back-solve (the iterative-refinement / quality-retry
    /// re-call path). `solve()` resets this to `false` whenever the
    /// dependency tags differ from `last_dep_tags`; the batched
    /// cached-factor paths return `None` while it is `false`.
    matrix_considered: bool,
    /// Tags of the 13 dependencies (W, J_c, J_d, z_L, z_U, v_L, v_U,
    /// slack_x_L, slack_x_U, slack_s_L, slack_s_U, sigma_x, sigma_s)
    /// as recorded by the most recent cache miss in `solve()`. Mirrors
    /// upstream's `dummy_cache_` keyed on the same 13 `TaggedObject`s
    /// (`IpPDFullSpaceSolver.cpp:430-448`). `None` until the first
    /// `solve()`; overwritten with the current tags whenever any of
    /// them differs from the stored ones.
    last_dep_tags: Option<[Tag; 13]>,
    /// Status recorded by the most recent call that got that far:
    /// `Success` at the end of a successful `solve()`, or the status
    /// reported by the batched back-solve. `None` after `new`.
    last_status: Option<ESymSolverStatus>,
}
60
61impl PdFullSpaceSolver {
62    pub fn new(
63        aug_solver: Box<dyn AugSystemSolver>,
64        perturb: Rc<RefCell<PdPerturbationHandler>>,
65    ) -> Self {
66        Self {
67            aug_solver,
68            perturb,
69            // Defaults from `IpPDFullSpaceSolver.cpp:RegisterOptions`.
70            min_refinement_steps: 1,
71            max_refinement_steps: 10,
72            residual_ratio_max: 1e-10,
73            residual_ratio_singular: 1e-5,
74            residual_improvement_factor: 0.999_999_999,
75            neg_curv_test_tol: 0.0,
76            augsys_improved: false,
77            matrix_considered: false,
78            last_dep_tags: None,
79            last_status: None,
80        }
81    }
82
83    pub fn aug_solver(&self) -> &dyn AugSystemSolver {
84        &*self.aug_solver
85    }
86
87    pub fn aug_solver_mut(&mut self) -> &mut dyn AugSystemSolver {
88        &mut *self.aug_solver
89    }
90
91    /// Replace the underlying [`AugSystemSolver`] by passing the
92    /// existing one through the supplied wrapper closure. Used by the
93    /// restoration phase to decorate the inner `StdAugSystemSolver`
94    /// with `AugRestoSystemSolver` (which performs the 8-block →
95    /// 4-block Schur reduction before delegating).
96    pub fn wrap_aug_solver<F>(&mut self, wrap: F)
97    where
98        F: FnOnce(Box<dyn AugSystemSolver>) -> Box<dyn AugSystemSolver>,
99    {
100        // Take the inner aug solver out via a temporary noop, wrap it,
101        // and slot the wrapped one back in. The placeholder is never
102        // observed externally because we replace it before returning.
103        let noop: Box<dyn AugSystemSolver> = Box::new(NoopAugSolver);
104        let inner = std::mem::replace(&mut self.aug_solver, noop);
105        self.aug_solver = wrap(inner);
106    }
107
    /// Solve the full PD system. `res = α · M⁻¹ · rhs + β · res_in`,
    /// matching `IpPDFullSpaceSolver::Solve`. Returns `true` on
    /// success. The iterate fields used to assemble the system are
    /// pulled from `data` (`W`, `curr`) and `cq` (jacobians, slacks,
    /// sigmas); the bound-expansion matrices come from `nlp`.
    ///
    /// Flags:
    /// - `allow_inexact`: accept the first successful `solve_once`
    ///   and skip iterative refinement entirely.
    /// - `improve_solution`: skip the initial `solve_once` and start
    ///   refining the incoming `res` directly.
    ///
    /// In debug builds the two flags are asserted to be mutually
    /// exclusive, and `improve_solution` additionally requires
    /// `beta == 0.0`.
    ///
    /// Returns `false` as soon as any `solve_once` call reports
    /// failure; `last_status` is only updated on the success path.
    ///
    /// # Panics
    ///
    /// Panics if `IpoptData::w` or `IpoptData::curr` is unset.
    #[allow(clippy::too_many_arguments)]
    pub fn solve(
        &mut self,
        data: &IpoptDataHandle,
        cq: &IpoptCqHandle,
        nlp: &Rc<RefCell<dyn IpoptNlp>>,
        alpha: Number,
        beta: Number,
        rhs: &IteratesVector,
        res: &mut IteratesVectorMut,
        allow_inexact: bool,
        improve_solution: bool,
    ) -> bool {
        debug_assert!(!allow_inexact || !improve_solution);
        debug_assert!(!improve_solution || beta == 0.0);

        // Snapshot the incoming `res` if β ≠ 0 (we add it back at the
        // end via `res = α · sol + β · copy_res`).
        let copy_res: Option<IteratesVector> = if beta != 0.0 {
            Some(snapshot_mut(res))
        } else {
            None
        };

        // Pull all blocks once. None of these change during the
        // refinement / escalation loop, so collecting them here
        // matches upstream's structure (lines 168-189).
        let w = data
            .borrow()
            .w
            .clone()
            .unwrap_or_else(|| panic!("PdFullSpaceSolver::solve: IpoptData::w is unset"));
        // The `cq` borrow is released explicitly once every quantity
        // has been fetched, so it is not held across the solves below.
        let cq_ref = cq.borrow();
        let j_c = cq_ref.curr_jac_c();
        let j_d = cq_ref.curr_jac_d();
        let sigma_x = cq_ref.curr_sigma_x();
        let sigma_s = cq_ref.curr_sigma_s();
        let slack_x_l = cq_ref.curr_slack_x_l();
        let slack_x_u = cq_ref.curr_slack_x_u();
        let slack_s_l = cq_ref.curr_slack_s_l();
        let slack_s_u = cq_ref.curr_slack_s_u();
        drop(cq_ref);

        // Bound-expansion matrices; same borrow-then-drop pattern.
        let nlp_ref = nlp.borrow();
        let px_l = nlp_ref.px_l();
        let px_u = nlp_ref.px_u();
        let pd_l = nlp_ref.pd_l();
        let pd_u = nlp_ref.pd_u();
        drop(nlp_ref);

        // Current iterate; the `data` borrow is confined to this block
        // because `data` is mutably borrowed later for the info string.
        let curr = {
            let d = data.borrow();
            d.curr
                .clone()
                .unwrap_or_else(|| panic!("PdFullSpaceSolver::solve: IpoptData::curr is unset"))
        };

        // Borrowed view over everything collected above; this is what
        // `solve_once` and `compute_residuals` receive.
        let blocks = SolveBlocks {
            w: &*w,
            j_c: &*j_c,
            j_d: &*j_d,
            px_l: &*px_l,
            px_u: &*px_u,
            pd_l: &*pd_l,
            pd_u: &*pd_u,
            z_l: &*curr.z_l,
            z_u: &*curr.z_u,
            v_l: &*curr.v_l,
            v_u: &*curr.v_u,
            slack_x_l: &*slack_x_l,
            slack_x_u: &*slack_x_u,
            slack_s_l: &*slack_s_l,
            slack_s_u: &*slack_s_u,
            sigma_x: &*sigma_x,
            sigma_s: &*sigma_s,
        };

        // Mirror upstream's `dummy_cache_` lookup
        // (`IpPDFullSpaceSolver.cpp:430-450`): if all 13 dependency tags
        // are unchanged since they were last recorded here, the matrix
        // is "uptodate" — `matrix_considered` and `augsys_improved`
        // are left as they are. On a cache miss, record the new tags
        // and reset both flags.
        let cur_tags: [Tag; 13] = [
            blocks.w.as_tagged().get_tag(),
            blocks.j_c.as_tagged().get_tag(),
            blocks.j_d.as_tagged().get_tag(),
            blocks.z_l.as_tagged().get_tag(),
            blocks.z_u.as_tagged().get_tag(),
            blocks.v_l.as_tagged().get_tag(),
            blocks.v_u.as_tagged().get_tag(),
            blocks.slack_x_l.as_tagged().get_tag(),
            blocks.slack_x_u.as_tagged().get_tag(),
            blocks.slack_s_l.as_tagged().get_tag(),
            blocks.slack_s_u.as_tagged().get_tag(),
            blocks.sigma_x.as_tagged().get_tag(),
            blocks.sigma_s.as_tagged().get_tag(),
        ];
        let uptodate = self.last_dep_tags.map_or(false, |prev| prev == cur_tags);
        if !uptodate {
            // Opt-in diagnostics: when the `POUNCE_DBG_PD_TAGS`
            // environment variable is set, log which tags changed
            // (or that this is the first solve).
            if std::env::var_os("POUNCE_DBG_PD_TAGS").is_some() {
                if let Some(prev) = self.last_dep_tags {
                    // Same order as `cur_tags` above.
                    let names = [
                        "w",
                        "j_c",
                        "j_d",
                        "z_l",
                        "z_u",
                        "v_l",
                        "v_u",
                        "slack_x_l",
                        "slack_x_u",
                        "slack_s_l",
                        "slack_s_u",
                        "sigma_x",
                        "sigma_s",
                    ];
                    let mut diffs = String::new();
                    for i in 0..13 {
                        if prev[i] != cur_tags[i] {
                            diffs.push_str(&format!(
                                " {}({:?}→{:?})",
                                names[i], prev[i], cur_tags[i]
                            ));
                        }
                    }
                    tracing::debug!(target: "pounce::linsol", "[PN_PD_TAGS] cache_miss diffs:{}", diffs);
                } else {
                    tracing::debug!(target: "pounce::linsol", "[PN_PD_TAGS] cache_miss first_solve");
                }
            }
            self.last_dep_tags = Some(cur_tags);
            self.matrix_considered = false;
            self.augsys_improved = false;
        }

        // Outer-loop state:
        // - `resolve_with_better_quality` / `pretend_singular` request
        //   another pass of the outer loop and are forwarded to the
        //   next initial `solve_once`;
        // - `pretend_singular_last_time` remembers that the singular
        //   route was already chosen once, so it is not chosen again;
        // - `improve` is the one-shot form of `improve_solution`.
        let mut done = false;
        let mut resolve_with_better_quality = false;
        let mut pretend_singular = false;
        let mut pretend_singular_last_time = false;
        let mut improve = improve_solution;

        while !done {
            // Initial solve for this pass — skipped on the very first
            // pass when the caller asked to improve an existing `res`.
            let solve_ok = if improve {
                true
            } else {
                let ok = self.solve_once(
                    data,
                    &blocks,
                    1.0,
                    0.0,
                    rhs,
                    res,
                    resolve_with_better_quality,
                    pretend_singular,
                );
                // Both requests have been consumed by this solve.
                resolve_with_better_quality = false;
                pretend_singular = false;
                ok
            };
            improve = false;

            if !solve_ok {
                return false;
            }

            // Inexact solves skip refinement but still go through the
            // final α/β assembly below.
            if allow_inexact {
                break;
            }

            // Initial residual.
            let mut resid = res.fresh_zeroed();
            self.compute_residuals(data, &blocks, rhs, res, &mut resid);
            let mut residual_ratio = self.compute_residual_ratio(rhs, res, &resid);
            let mut residual_ratio_old = residual_ratio;

            let mut num_iter_ref: Index = 0;
            let mut quit_refinement = false;

            // Iterative refinement: keep going until the minimum step
            // count is reached and the residual ratio is acceptable,
            // or until the quit criteria below fire.
            while !quit_refinement
                && (num_iter_ref < self.min_refinement_steps
                    || residual_ratio > self.residual_ratio_max)
            {
                // `solve_once` takes the right-hand side as an
                // immutable `IteratesVector`, so the residual is
                // frozen for the call and thawed again afterwards.
                // With coefficients (-1, 1) this is presumably
                // `res ← res − M⁻¹·resid` (`solve_once` is outside
                // this view — confirm).
                let frozen_resid = resid.freeze();
                let solve_ok = self.solve_once(
                    data,
                    &blocks,
                    -1.0,
                    1.0,
                    &frozen_resid,
                    res,
                    resolve_with_better_quality,
                    false,
                );
                resid = thaw(frozen_resid);
                if !solve_ok {
                    return false;
                }

                self.compute_residuals(data, &blocks, rhs, res, &mut resid);
                residual_ratio = self.compute_residual_ratio(rhs, res, &resid);
                num_iter_ref += 1;

                // Quit criteria: the residual is still too large, the
                // minimum step count has been exceeded, and either the
                // step budget is exhausted or the last step did not
                // improve the ratio by the required factor.
                if residual_ratio > self.residual_ratio_max
                    && num_iter_ref > self.min_refinement_steps
                    && (num_iter_ref > self.max_refinement_steps
                        || residual_ratio > self.residual_improvement_factor * residual_ratio_old)
                {
                    quit_refinement = true;
                    resolve_with_better_quality = false;

                    if !pretend_singular_last_time {
                        // First choice: ask the linear solver for a
                        // higher-quality factorization (info "q").
                        // If that was already done, or is not
                        // possible, fall back to pretending the
                        // system is singular.
                        if !self.augsys_improved {
                            self.augsys_improved = self.aug_solver.increase_quality();
                            if self.augsys_improved {
                                data.borrow_mut().append_info_string("q");
                                resolve_with_better_quality = true;
                            } else {
                                pretend_singular = true;
                            }
                        } else {
                            pretend_singular = true;
                        }
                        pretend_singular_last_time = pretend_singular;
                        if pretend_singular {
                            // Only go through with it when the
                            // residual is really bad: below
                            // `residual_ratio_singular` the current
                            // result is accepted (info "S"),
                            // otherwise the outer loop re-solves
                            // with `pretend_singular` set (info "s").
                            if residual_ratio < self.residual_ratio_singular {
                                pretend_singular = false;
                                data.borrow_mut().append_info_string("S");
                            } else {
                                data.borrow_mut().append_info_string("s");
                            }
                        }
                    } else {
                        // The singular route was already taken once;
                        // accept the current result.
                        pretend_singular = false;
                    }
                }

                residual_ratio_old = residual_ratio;
            }

            // Another pass is needed only if escalation was requested.
            done = !resolve_with_better_quality && !pretend_singular;
        }

        // Final assembly: res = α · res + β · copy_res.
        // The scaling is skipped when `alpha == 0.0`, so in that case
        // `res` keeps the unscaled solution rather than being zeroed.
        // NOTE(review): confirm callers never rely on α = 0 zeroing `res`.
        if alpha != 0.0 {
            res.scal(alpha);
        }
        if let Some(copy_res) = copy_res {
            res.axpy(beta, &copy_res);
        }

        self.last_status = Some(ESymSolverStatus::Success);
        true
    }
368
369    /// Batched back-substitution against the cached KKT factor for
370    /// `n_rhs` right-hand sides, sharing one
371    /// `pounce_linsol::TSymLinearSolver::multi_solve` call with
372    /// `nrhs > 1`. Each column k pulls its RHS through `write_rhs(k,
373    /// &mut iv)` and emits its solution through `write_lhs(k, &iv)` —
374    /// closures over the caller's flat / strided buffer keep the
375    /// rhs/sol `IteratesVectorMut` scratch out of the API surface.
376    ///
377    /// Returns:
378    /// - `Some(true)`  — fast path executed against the cached factor.
379    /// - `Some(false)` — fast path was attempted but the linsol
380    ///   reported a back-solve failure.
381    /// - `None`        — fast path not taken. Either the matrix tags
382    ///   differ from the last successful [`Self::solve`] (cache miss),
383    ///   the matrix has not been considered yet, or the underlying
384    ///   `AugSystemSolver` does not implement
385    ///   [`AugSystemSolver::try_resolve_many_flat`]. The caller should
386    ///   fall back to looping [`Self::solve`].
387    ///
388    /// Used by `pounce_sensitivity::PdSensBacksolver::solve_many` for
389    /// the JaxProblem `jacrev` backward path, where every cotangent
390    /// re-solves against the same converged factor (pounce#77 follow-up).
391    pub fn solve_many_cached<F1, F2>(
392        &mut self,
393        data: &IpoptDataHandle,
394        cq: &IpoptCqHandle,
395        nlp: &Rc<RefCell<dyn IpoptNlp>>,
396        n_rhs: usize,
397        mut write_rhs: F1,
398        mut write_lhs: F2,
399    ) -> Option<bool>
400    where
401        F1: FnMut(usize, &mut IteratesVectorMut),
402        F2: FnMut(usize, &IteratesVectorMut),
403    {
404        if n_rhs == 0 {
405            return Some(true);
406        }
407
408        // Pull all blocks (same shape as `solve()`).
409        let w = data.borrow().w.clone()?;
410        let cq_ref = cq.borrow();
411        let j_c = cq_ref.curr_jac_c();
412        let j_d = cq_ref.curr_jac_d();
413        let sigma_x = cq_ref.curr_sigma_x();
414        let sigma_s = cq_ref.curr_sigma_s();
415        let slack_x_l = cq_ref.curr_slack_x_l();
416        let slack_x_u = cq_ref.curr_slack_x_u();
417        let slack_s_l = cq_ref.curr_slack_s_l();
418        let slack_s_u = cq_ref.curr_slack_s_u();
419        drop(cq_ref);
420
421        let nlp_ref = nlp.borrow();
422        let px_l = nlp_ref.px_l();
423        let px_u = nlp_ref.px_u();
424        let pd_l = nlp_ref.pd_l();
425        let pd_u = nlp_ref.pd_u();
426        drop(nlp_ref);
427
428        let curr = data.borrow().curr.clone()?;
429
430        let blocks = SolveBlocks {
431            w: &*w,
432            j_c: &*j_c,
433            j_d: &*j_d,
434            px_l: &*px_l,
435            px_u: &*px_u,
436            pd_l: &*pd_l,
437            pd_u: &*pd_u,
438            z_l: &*curr.z_l,
439            z_u: &*curr.z_u,
440            v_l: &*curr.v_l,
441            v_u: &*curr.v_u,
442            slack_x_l: &*slack_x_l,
443            slack_x_u: &*slack_x_u,
444            slack_s_l: &*slack_s_l,
445            slack_s_u: &*slack_s_u,
446            sigma_x: &*sigma_x,
447            sigma_s: &*sigma_s,
448        };
449
450        // Cache-tag check (same 13 tags as `solve()`). If the matrix
451        // has changed since the last successful solve, or we never
452        // marked it as considered, bail and let the caller take the
453        // per-RHS path.
454        let cur_tags: [Tag; 13] = [
455            blocks.w.as_tagged().get_tag(),
456            blocks.j_c.as_tagged().get_tag(),
457            blocks.j_d.as_tagged().get_tag(),
458            blocks.z_l.as_tagged().get_tag(),
459            blocks.z_u.as_tagged().get_tag(),
460            blocks.v_l.as_tagged().get_tag(),
461            blocks.v_u.as_tagged().get_tag(),
462            blocks.slack_x_l.as_tagged().get_tag(),
463            blocks.slack_x_u.as_tagged().get_tag(),
464            blocks.slack_s_l.as_tagged().get_tag(),
465            blocks.slack_s_u.as_tagged().get_tag(),
466            blocks.sigma_x.as_tagged().get_tag(),
467            blocks.sigma_s.as_tagged().get_tag(),
468        ];
469        if !self.matrix_considered || !self.last_dep_tags.map_or(false, |prev| prev == cur_tags) {
470            return None;
471        }
472
473        // Coeffs reuse the perturbation stashed by the most recent
474        // `solve_once`. `current_perturbation()` returns the same
475        // values that solve_once wrote into `data.perturbations`.
476        let d = self.perturb.borrow().current_perturbation();
477        let coeffs = AugSysCoeffs {
478            w: Some(blocks.w),
479            w_factor: 1.0,
480            d_x: Some(blocks.sigma_x),
481            delta_x: d.delta_x,
482            d_s: Some(blocks.sigma_s),
483            delta_s: d.delta_s,
484            j_c: blocks.j_c,
485            d_c: None,
486            delta_c: d.delta_c,
487            j_d: blocks.j_d,
488            d_d: None,
489            delta_d: d.delta_d,
490        };
491
492        let n_x = curr.x.dim() as usize;
493        let n_s = curr.s.dim() as usize;
494        let n_y_c = curr.y_c.dim() as usize;
495        let n_y_d = curr.y_d.dim() as usize;
496        let aug_dim = n_x + n_s + n_y_c + n_y_d;
497
498        // Scratch — one set of Box allocs, reused across every column.
499        let mut rhs_iv = curr.make_new_zeroed();
500        let mut sol_iv = curr.make_new_zeroed();
501        let mut aug_rhs_x_box: Box<dyn Vector> = curr.x.make_new();
502        let mut aug_rhs_s_box: Box<dyn Vector> = curr.s.make_new();
503
504        // Column-major `(aug_dim, n_rhs)` packed buffer — single
505        // allocation that the linsol's `multi_solve` writes solutions
506        // back into in place.
507        let mut aug_packed = vec![0.0 as Number; aug_dim * n_rhs];
508
509        // Phase 1: populate aug_packed column-by-column. The aug-system
510        // RHS is `[aug_rhs_x | aug_rhs_s | rhs.y_c | rhs.y_d]`, where
511        //   aug_rhs_x = rhs.x + Px_L·S_xL⁻¹·z_L − Px_U·S_xU⁻¹·z_U
512        //   aug_rhs_s = rhs.s + Pd_L·S_sL⁻¹·v_L − Pd_U·S_sU⁻¹·v_U
513        // matching `solve_once`'s aug-RHS build.
514        for k in 0..n_rhs {
515            write_rhs(k, &mut rhs_iv);
516
517            aug_rhs_x_box.copy(&*rhs_iv.x);
518            blocks
519                .px_l
520                .add_m_sinv_z(1.0, blocks.slack_x_l, &*rhs_iv.z_l, &mut *aug_rhs_x_box);
521            blocks
522                .px_u
523                .add_m_sinv_z(-1.0, blocks.slack_x_u, &*rhs_iv.z_u, &mut *aug_rhs_x_box);
524
525            aug_rhs_s_box.copy(&*rhs_iv.s);
526            blocks
527                .pd_l
528                .add_m_sinv_z(1.0, blocks.slack_s_l, &*rhs_iv.v_l, &mut *aug_rhs_s_box);
529            blocks
530                .pd_u
531                .add_m_sinv_z(-1.0, blocks.slack_s_u, &*rhs_iv.v_u, &mut *aug_rhs_s_box);
532
533            let col = &mut aug_packed[k * aug_dim..(k + 1) * aug_dim];
534            copy_vector_to_slice(&*aug_rhs_x_box, &mut col[..n_x]);
535            copy_vector_to_slice(&*aug_rhs_s_box, &mut col[n_x..n_x + n_s]);
536            copy_vector_to_slice(&*rhs_iv.y_c, &mut col[n_x + n_s..n_x + n_s + n_y_c]);
537            copy_vector_to_slice(&*rhs_iv.y_d, &mut col[n_x + n_s + n_y_c..]);
538        }
539
540        // Phase 2: single batched back-substitution.
541        let status = self
542            .aug_solver
543            .try_resolve_many_flat(&coeffs, &mut aug_packed, n_rhs)?;
544        if status != ESymSolverStatus::Success {
545            self.last_status = Some(status);
546            return Some(false);
547        }
548        self.last_status = Some(status);
549
550        // Phase 3: unpack each column into `sol_iv`, run the bound-
551        // multiplier expansion, hand the result to the caller. We have
552        // to re-invoke `write_rhs` because expand_bound_multipliers
553        // reads `rhs.z_l/z_u/v_l/v_u` and we re-used `rhs_iv` across
554        // all columns in phase 1.
555        for k in 0..n_rhs {
556            write_rhs(k, &mut rhs_iv);
557
558            let col = &aug_packed[k * aug_dim..(k + 1) * aug_dim];
559            set_vector_from_slice(&mut *sol_iv.x, &col[..n_x]);
560            set_vector_from_slice(&mut *sol_iv.s, &col[n_x..n_x + n_s]);
561            set_vector_from_slice(&mut *sol_iv.y_c, &col[n_x + n_s..n_x + n_s + n_y_c]);
562            set_vector_from_slice(&mut *sol_iv.y_d, &col[n_x + n_s + n_y_c..]);
563
564            // Inline expand_bound_multipliers — that helper takes
565            // `&IteratesVector` (Rc-backed) but our `rhs_iv` is
566            // `IteratesVectorMut` (Box-backed). The four
567            // `sinv_blrm_zmt_dbr` calls work on `&dyn Vector` either
568            // way.
569            blocks.px_l.sinv_blrm_zmt_dbr(
570                -1.0,
571                blocks.slack_x_l,
572                &*rhs_iv.z_l,
573                blocks.z_l,
574                &*sol_iv.x,
575                &mut *sol_iv.z_l,
576            );
577            blocks.px_u.sinv_blrm_zmt_dbr(
578                1.0,
579                blocks.slack_x_u,
580                &*rhs_iv.z_u,
581                blocks.z_u,
582                &*sol_iv.x,
583                &mut *sol_iv.z_u,
584            );
585            blocks.pd_l.sinv_blrm_zmt_dbr(
586                -1.0,
587                blocks.slack_s_l,
588                &*rhs_iv.v_l,
589                blocks.v_l,
590                &*sol_iv.s,
591                &mut *sol_iv.v_l,
592            );
593            blocks.pd_u.sinv_blrm_zmt_dbr(
594                1.0,
595                blocks.slack_s_u,
596                &*rhs_iv.v_u,
597                blocks.v_u,
598                &*sol_iv.s,
599                &mut *sol_iv.v_u,
600            );
601
602            write_lhs(k, &sol_iv);
603        }
604
605        Some(true)
606    }
607
608    /// Flat-slice cached-factor multi-RHS path. Same cache-check
609    /// semantics as [`Self::solve_many_cached`] but operates on
610    /// row-major `(n_rhs, total)` flat buffers without going through
611    /// `IteratesVectorMut` or any `dyn Vector` / `dyn Matrix` dispatch
612    /// in the per-RHS inner loops — the eight source blocks
613    /// (`slack_{x,s}_{l,u}`, `z_{l,u}`, `v_{l,u}`) get downcast to
614    /// `DenseVector` once at the top, the four bound-expansion matrices
615    /// (`px_l`, `px_u`, `pd_l`, `pd_u`) get downcast to
616    /// `ExpansionMatrix` once, and Phase 1 / Phase 3 then run as raw
617    /// `&[Number]` / `&mut [Number]` arithmetic on the flat buffers.
618    ///
619    /// `total` is the sum of the eight `block_dims` entries (in the
620    /// same `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order that
621    /// `IteratesVector` uses); `rhs_flat.len() == lhs_flat.len() ==
622    /// n_rhs * total`.
623    ///
624    /// Returns `None` (caller should fall back to
625    /// [`Self::solve_many_cached`]) when:
626    /// - the cache check fails (matrix tags differ),
627    /// - any block source vector is not a `DenseVector` or is
628    ///   homogeneous (uniform-scalar) on a non-empty block,
629    /// - any bound-expansion matrix is not an `ExpansionMatrix`,
630    /// - the underlying `AugSystemSolver` doesn't implement
631    ///   [`AugSystemSolver::try_resolve_many_flat`].
632    ///
633    /// Returns `Some(true)` on success, `Some(false)` on linsol back-
634    /// solve failure.
635    ///
636    /// Used by `pounce_sensitivity::PdSensBacksolver::solve_many` as
637    /// the fastest tier of the JaxProblem `jacrev` backward path
638    /// (pounce#77 follow-up).
639    #[allow(clippy::too_many_arguments)]
640    pub fn solve_many_cached_flat(
641        &mut self,
642        data: &IpoptDataHandle,
643        cq: &IpoptCqHandle,
644        nlp: &Rc<RefCell<dyn IpoptNlp>>,
645        n_rhs: usize,
646        rhs_flat: &[Number],
647        lhs_flat: &mut [Number],
648        block_dims: [usize; 8],
649    ) -> Option<bool> {
650        if n_rhs == 0 {
651            return Some(true);
652        }
653        let total: usize = block_dims.iter().sum();
654        if rhs_flat.len() != n_rhs * total || lhs_flat.len() != n_rhs * total {
655            return Some(false);
656        }
657        let mut off = [0usize; 9];
658        for i in 0..8 {
659            off[i + 1] = off[i] + block_dims[i];
660        }
661        let n_x = block_dims[0];
662        let n_s = block_dims[1];
663        let n_y_c = block_dims[2];
664        let n_y_d = block_dims[3];
665
666        // Pull all blocks (same shape as `solve()`).
667        let w = data.borrow().w.clone()?;
668        let cq_ref = cq.borrow();
669        let j_c = cq_ref.curr_jac_c();
670        let j_d = cq_ref.curr_jac_d();
671        let sigma_x = cq_ref.curr_sigma_x();
672        let sigma_s = cq_ref.curr_sigma_s();
673        let slack_x_l = cq_ref.curr_slack_x_l();
674        let slack_x_u = cq_ref.curr_slack_x_u();
675        let slack_s_l = cq_ref.curr_slack_s_l();
676        let slack_s_u = cq_ref.curr_slack_s_u();
677        drop(cq_ref);
678
679        let nlp_ref = nlp.borrow();
680        let px_l = nlp_ref.px_l();
681        let px_u = nlp_ref.px_u();
682        let pd_l = nlp_ref.pd_l();
683        let pd_u = nlp_ref.pd_u();
684        drop(nlp_ref);
685
686        let curr = data.borrow().curr.clone()?;
687
688        // Cache-tag check (same 13 tags as `solve()`).
689        let cur_tags: [Tag; 13] = [
690            w.as_tagged().get_tag(),
691            j_c.as_tagged().get_tag(),
692            j_d.as_tagged().get_tag(),
693            curr.z_l.as_tagged().get_tag(),
694            curr.z_u.as_tagged().get_tag(),
695            curr.v_l.as_tagged().get_tag(),
696            curr.v_u.as_tagged().get_tag(),
697            slack_x_l.as_tagged().get_tag(),
698            slack_x_u.as_tagged().get_tag(),
699            slack_s_l.as_tagged().get_tag(),
700            slack_s_u.as_tagged().get_tag(),
701            sigma_x.as_tagged().get_tag(),
702            sigma_s.as_tagged().get_tag(),
703        ];
704        if !self.matrix_considered || !self.last_dep_tags.map_or(false, |prev| prev == cur_tags) {
705            return None;
706        }
707
708        // Concrete downcasts. Bail to closure-based fallback on any
709        // type mismatch (homogeneous-on-non-empty included — the math
710        // below assumes a real `[Number]` slice for slack / z / v).
711        let slack_x_l_d = dense_slice_or_none(&*slack_x_l, block_dims[4])?;
712        let slack_x_u_d = dense_slice_or_none(&*slack_x_u, block_dims[5])?;
713        let slack_s_l_d = dense_slice_or_none(&*slack_s_l, block_dims[6])?;
714        let slack_s_u_d = dense_slice_or_none(&*slack_s_u, block_dims[7])?;
715        let blocks_z_l_d = dense_slice_or_none(&*curr.z_l, block_dims[4])?;
716        let blocks_z_u_d = dense_slice_or_none(&*curr.z_u, block_dims[5])?;
717        let blocks_v_l_d = dense_slice_or_none(&*curr.v_l, block_dims[6])?;
718        let blocks_v_u_d = dense_slice_or_none(&*curr.v_u, block_dims[7])?;
719
720        let exp_x_l = exp_pos_or_none(&*px_l)?;
721        let exp_x_u = exp_pos_or_none(&*px_u)?;
722        let exp_s_l = exp_pos_or_none(&*pd_l)?;
723        let exp_s_u = exp_pos_or_none(&*pd_u)?;
724
725        // Coeffs reuse the perturbation stashed by the most recent
726        // `solve_once`.
727        let d = self.perturb.borrow().current_perturbation();
728        let coeffs = AugSysCoeffs {
729            w: Some(&*w),
730            w_factor: 1.0,
731            d_x: Some(&*sigma_x),
732            delta_x: d.delta_x,
733            d_s: Some(&*sigma_s),
734            delta_s: d.delta_s,
735            j_c: &*j_c,
736            d_c: None,
737            delta_c: d.delta_c,
738            j_d: &*j_d,
739            d_d: None,
740            delta_d: d.delta_d,
741        };
742
743        let aug_dim = n_x + n_s + n_y_c + n_y_d;
744        // Column-major `(aug_dim, n_rhs)` packed buffer, single alloc.
745        let mut aug_packed = vec![0.0 as Number; aug_dim * n_rhs];
746
747        // ---------------- Phase 1 ----------------
748        // For each k: build the aug-system RHS into column k of
749        // aug_packed, all inline against raw slices.
750        for k in 0..n_rhs {
751            let r_base = k * total;
752            let rhs_x = &rhs_flat[r_base + off[0]..r_base + off[1]];
753            let rhs_s = &rhs_flat[r_base + off[1]..r_base + off[2]];
754            let rhs_y_c = &rhs_flat[r_base + off[2]..r_base + off[3]];
755            let rhs_y_d = &rhs_flat[r_base + off[3]..r_base + off[4]];
756            let rhs_z_l = &rhs_flat[r_base + off[4]..r_base + off[5]];
757            let rhs_z_u = &rhs_flat[r_base + off[5]..r_base + off[6]];
758            let rhs_v_l = &rhs_flat[r_base + off[6]..r_base + off[7]];
759            let rhs_v_u = &rhs_flat[r_base + off[7]..r_base + off[8]];
760
761            let aug_col = &mut aug_packed[k * aug_dim..(k + 1) * aug_dim];
762            let (aug_x, rest) = aug_col.split_at_mut(n_x);
763            let (aug_s, rest) = rest.split_at_mut(n_s);
764            let (aug_y_c, aug_y_d) = rest.split_at_mut(n_y_c);
765
766            // aug_x = rhs_x + Px_L · S_xL⁻¹ · z_L − Px_U · S_xU⁻¹ · z_U
767            aug_x.copy_from_slice(rhs_x);
768            scatter_add_div(aug_x, exp_x_l, rhs_z_l, slack_x_l_d, 1.0);
769            scatter_add_div(aug_x, exp_x_u, rhs_z_u, slack_x_u_d, -1.0);
770            // aug_s = rhs_s + Pd_L · S_sL⁻¹ · v_L − Pd_U · S_sU⁻¹ · v_U
771            aug_s.copy_from_slice(rhs_s);
772            scatter_add_div(aug_s, exp_s_l, rhs_v_l, slack_s_l_d, 1.0);
773            scatter_add_div(aug_s, exp_s_u, rhs_v_u, slack_s_u_d, -1.0);
774            aug_y_c.copy_from_slice(rhs_y_c);
775            aug_y_d.copy_from_slice(rhs_y_d);
776        }
777
778        // ---------------- Phase 2 ----------------
779        let status = self
780            .aug_solver
781            .try_resolve_many_flat(&coeffs, &mut aug_packed, n_rhs)?;
782        if status != ESymSolverStatus::Success {
783            self.last_status = Some(status);
784            return Some(false);
785        }
786        self.last_status = Some(status);
787
788        // ---------------- Phase 3 ----------------
789        // For each k: copy sol_x/s/y_c/y_d into lhs_flat, then build
790        // sol_z_l/z_u/v_l/v_u from the bound-multiplier expansion.
791        for k in 0..n_rhs {
792            let r_base = k * total;
793            let rhs_z_l = &rhs_flat[r_base + off[4]..r_base + off[5]];
794            let rhs_z_u = &rhs_flat[r_base + off[5]..r_base + off[6]];
795            let rhs_v_l = &rhs_flat[r_base + off[6]..r_base + off[7]];
796            let rhs_v_u = &rhs_flat[r_base + off[7]..r_base + off[8]];
797
798            let aug_col = &aug_packed[k * aug_dim..(k + 1) * aug_dim];
799            let sol_x = &aug_col[..n_x];
800            let sol_s = &aug_col[n_x..n_x + n_s];
801            let sol_y_c = &aug_col[n_x + n_s..n_x + n_s + n_y_c];
802            let sol_y_d = &aug_col[n_x + n_s + n_y_c..];
803
804            let l_base = k * total;
805            let (lhs_xs, lhs_zv) = lhs_flat[l_base..l_base + total].split_at_mut(off[4]);
806            let (lhs_x, rest) = lhs_xs.split_at_mut(n_x);
807            let (lhs_s, rest) = rest.split_at_mut(n_s);
808            let (lhs_y_c, lhs_y_d) = rest.split_at_mut(n_y_c);
809            lhs_x.copy_from_slice(sol_x);
810            lhs_s.copy_from_slice(sol_s);
811            lhs_y_c.copy_from_slice(sol_y_c);
812            lhs_y_d.copy_from_slice(sol_y_d);
813
814            let (lhs_z_l, rest) = lhs_zv.split_at_mut(block_dims[4]);
815            let (lhs_z_u, rest) = rest.split_at_mut(block_dims[5]);
816            let (lhs_v_l, lhs_v_u) = rest.split_at_mut(block_dims[6]);
817
818            // sol_z_l[i] = (rhs_z_l[i] − z_l[i] · sol_x[exp_x_l[i]]) / slack_x_l[i]
819            expand_bound_mult(
820                lhs_z_l,
821                rhs_z_l,
822                blocks_z_l_d,
823                sol_x,
824                exp_x_l,
825                slack_x_l_d,
826                -1.0,
827            );
828            // sol_z_u[i] = (rhs_z_u[i] + z_u[i] · sol_x[exp_x_u[i]]) / slack_x_u[i]
829            expand_bound_mult(
830                lhs_z_u,
831                rhs_z_u,
832                blocks_z_u_d,
833                sol_x,
834                exp_x_u,
835                slack_x_u_d,
836                1.0,
837            );
838            expand_bound_mult(
839                lhs_v_l,
840                rhs_v_l,
841                blocks_v_l_d,
842                sol_s,
843                exp_s_l,
844                slack_s_l_d,
845                -1.0,
846            );
847            expand_bound_mult(
848                lhs_v_u,
849                rhs_v_u,
850                blocks_v_u_d,
851                sol_s,
852                exp_s_u,
853                slack_s_u_d,
854                1.0,
855            );
856        }
857
858        Some(true)
859    }
860
    /// One outer back-solve through the augmented system, including
    /// the `Px_L · S_xL⁻¹ · z_L` lifts on the RHS and the bound-
    /// multiplier expansion on the solution side. Mirrors
    /// `IpPDFullSpaceSolver::SolveOnce`.
    ///
    /// On success the freshly computed step `sol` is accumulated into
    /// the caller's vector as `res = α · sol + β · res`.
    ///
    /// Arguments:
    /// * `data` — source of `curr_mu`; also receives the perturbation
    ///   values that were finally used and the `"q"` / `"a"` info
    ///   markers.
    /// * `b` — borrowed matrix/vector blocks of the KKT system.
    /// * `alpha`, `beta` — scaling factors of the final accumulation.
    /// * `rhs` — full right-hand side (all eight blocks).
    /// * `res` — accumulation target; also the shape template for the
    ///   temporary solution vector.
    /// * `_resolve_with_better_quality` — accepted but not read here.
    /// * `pretend_singular` — when `true`, the first pass of the
    ///   escalation loop skips the linear solve and behaves as if the
    ///   solver had reported `Singular`.
    ///
    /// Returns `false` when the same-matrix re-solve does not report
    /// `Success`, when the solver reports `FatalError`, or when the
    /// perturbation handler yields no further perturbation (`None`);
    /// returns `true` otherwise. On the `false` paths `res` is left
    /// untouched.
    ///
    /// Side effects on `self`: sets `matrix_considered` after a
    /// successful escalation loop and may set `augsys_improved` when a
    /// quality increase succeeds.
    #[allow(clippy::too_many_arguments)]
    fn solve_once(
        &mut self,
        data: &IpoptDataHandle,
        b: &SolveBlocks<'_>,
        alpha: Number,
        beta: Number,
        rhs: &IteratesVector,
        res: &mut IteratesVectorMut,
        _resolve_with_better_quality: bool,
        mut pretend_singular: bool,
    ) -> bool {
        // Build aug-system primal RHS:
        //   augRhs_x = rhs.x + Px_L · S_xL⁻¹ · z_L − Px_U · S_xU⁻¹ · z_U
        let mut aug_rhs_x = rhs.x.make_new_copy();
        b.px_l
            .add_m_sinv_z(1.0, b.slack_x_l, &*rhs.z_l, &mut *aug_rhs_x);
        b.px_u
            .add_m_sinv_z(-1.0, b.slack_x_u, &*rhs.z_u, &mut *aug_rhs_x);

        // Same lift for the slack block:
        //   augRhs_s = rhs.s + Pd_L · S_sL⁻¹ · v_L − Pd_U · S_sU⁻¹ · v_U
        let mut aug_rhs_s = rhs.s.make_new_copy();
        b.pd_l
            .add_m_sinv_z(1.0, b.slack_s_l, &*rhs.v_l, &mut *aug_rhs_s);
        b.pd_u
            .add_m_sinv_z(-1.0, b.slack_s_u, &*rhs.v_u, &mut *aug_rhs_s);

        // Solution slot for the aug-system (dx, ds, dy_c, dy_d).
        let mut sol = res.fresh_zeroed();

        // Number of negative eigenvalues we expect.
        let num_neg_evals = rhs.y_c.dim() + rhs.y_d.dim();

        let curr_mu = data.borrow().curr_mu;

        // Upstream's `IpPDFullSpaceSolver::SolveOnce` (cpp:457-482)
        // splits on `(uptodate && !pretend_singular)`: if the matrix is
        // unchanged since the last `SolveOnce` and we are not faking a
        // singularity, reuse the existing perturbation, do a single
        // back-solve with `check_inertia=false`, and return. Iterative
        // refinement and the post-`IncreaseQuality` retry both land
        // here. Calling `ConsiderNewSystem` again on a same-matrix
        // re-solve would corrupt the perturbation handler's
        // `delta_x_last` bookkeeping.
        if self.matrix_considered && !pretend_singular {
            let d = self.perturb.borrow().current_perturbation();
            let coeffs = AugSysCoeffs {
                w: Some(b.w),
                w_factor: 1.0,
                d_x: Some(b.sigma_x),
                delta_x: d.delta_x,
                d_s: Some(b.sigma_s),
                delta_s: d.delta_s,
                j_c: b.j_c,
                d_c: None,
                delta_c: d.delta_c,
                j_d: b.j_d,
                d_d: None,
                delta_d: d.delta_d,
            };
            let aug_rhs = AugSysRhs {
                rhs_x: &*aug_rhs_x,
                rhs_s: &*aug_rhs_s,
                rhs_c: &*rhs.y_c,
                rhs_d: &*rhs.y_d,
            };
            let mut aug_sol = AugSysSol {
                sol_x: &mut *sol.x,
                sol_s: &mut *sol.s,
                sol_c: &mut *sol.y_c,
                sol_d: &mut *sol.y_d,
            };
            // Same matrix, same perturbations, inertia already known —
            // use the cached factor and avoid the per-call refactor
            // that otherwise dominates MA57 wall-time on long iter-ref
            // loops (cont5_2_4_l drops 97s → ~30s).
            let retval = self.aug_solver.resolve(&coeffs, &aug_rhs, &mut aug_sol);
            if retval != ESymSolverStatus::Success {
                return false;
            }
            // Stash perturbations on data, expand bound multipliers,
            // assemble final res, and return — skipping the
            // escalation loop entirely (matches upstream's `if(uptodate
            // && !pretend_singular)` branch in IpPDFullSpaceSolver.cpp).
            {
                let mut dm = data.borrow_mut();
                dm.perturbations.delta_x = d.delta_x;
                dm.perturbations.delta_s = d.delta_s;
                dm.perturbations.delta_c = d.delta_c;
                dm.perturbations.delta_d = d.delta_d;
            }
            expand_bound_multipliers(b, rhs, &mut sol);
            let frozen_sol = sol.freeze();
            res.add_one_vector(alpha, &frozen_sol, beta);
            return true;
        }

        // New matrix (or a faked singularity): ask the perturbation
        // handler for the starting perturbation. `None` means the
        // handler has nothing to offer, so the solve is abandoned.
        let mut deltas = self
            .perturb
            .borrow_mut()
            .consider_new_system(curr_mu, Some(data));
        let Some(mut d) = deltas.take() else {
            return false;
        };

        // Escalation loop: keep adjusting the perturbation `d` (or the
        // solver quality) until the augmented system solves with
        // `Success`. `count` tallies real solve attempts; it is only
        // discarded below and never reported.
        let mut count = 0_i32;
        let mut retval;
        loop {
            if pretend_singular {
                // Fake exactly one `Singular` outcome without calling
                // the linear solver, then fall back to real solves.
                retval = ESymSolverStatus::Singular;
                pretend_singular = false;
            } else {
                count += 1;
                // Inertia is checked only while the negative-curvature
                // heuristic is disabled (tolerance <= 0).
                let check_inertia = self.neg_curv_test_tol <= 0.0;
                let coeffs = AugSysCoeffs {
                    w: Some(b.w),
                    w_factor: 1.0,
                    d_x: Some(b.sigma_x),
                    delta_x: d.delta_x,
                    d_s: Some(b.sigma_s),
                    delta_s: d.delta_s,
                    j_c: b.j_c,
                    d_c: None,
                    delta_c: d.delta_c,
                    j_d: b.j_d,
                    d_d: None,
                    delta_d: d.delta_d,
                };
                let aug_rhs = AugSysRhs {
                    rhs_x: &*aug_rhs_x,
                    rhs_s: &*aug_rhs_s,
                    rhs_c: &*rhs.y_c,
                    rhs_d: &*rhs.y_d,
                };
                let mut aug_sol = AugSysSol {
                    sol_x: &mut *sol.x,
                    sol_s: &mut *sol.s,
                    sol_c: &mut *sol.y_c,
                    sol_d: &mut *sol.y_d,
                };
                retval = self.aug_solver.solve(
                    &coeffs,
                    &aug_rhs,
                    &mut aug_sol,
                    check_inertia,
                    num_neg_evals,
                );
            }

            if retval == ESymSolverStatus::FatalError {
                return false;
            }

            if retval == ESymSolverStatus::Singular && (rhs.y_c.dim() + rhs.y_d.dim() > 0) {
                // Singular with constraints present: request the
                // singularity perturbation.
                let curr_mu = data.borrow().curr_mu;
                let next = self
                    .perturb
                    .borrow_mut()
                    .perturb_for_singular(curr_mu, Some(data));
                let Some(nd) = next else { return false };
                d = nd;
            } else if retval == ESymSolverStatus::WrongInertia
                && self.aug_solver.number_of_neg_evals() < num_neg_evals
            {
                // Too few negative eigenvalues: first try (once) to
                // raise the solver quality; only if that is not
                // possible treat the system as singular.
                let mut assume_singular = true;
                if !self.augsys_improved {
                    self.augsys_improved = self.aug_solver.increase_quality();
                    if self.augsys_improved {
                        data.borrow_mut().append_info_string("q");
                        assume_singular = false;
                    }
                }
                if assume_singular {
                    let curr_mu = data.borrow().curr_mu;
                    let next = self
                        .perturb
                        .borrow_mut()
                        .perturb_for_singular(curr_mu, Some(data));
                    let Some(nd) = next else { return false };
                    d = nd;
                    data.borrow_mut().append_info_string("a");
                }
            } else if retval == ESymSolverStatus::WrongInertia
                || retval == ESymSolverStatus::Singular
            {
                // Remaining wrong-inertia / singular cases (including
                // `Singular` with no constraint rows): request the
                // wrong-inertia perturbation.
                let curr_mu = data.borrow().curr_mu;
                let next = self
                    .perturb
                    .borrow_mut()
                    .perturb_for_wrong_inertia(curr_mu, Some(data));
                let Some(nd) = next else { return false };
                d = nd;
            }

            if retval == ESymSolverStatus::Success {
                break;
            }
        }
        // `count` is intentionally unused beyond this point.
        let _ = count;

        // Stash the perturbation on data — upstream calls
        // `IpData().setPDPert(...)` here.
        {
            let mut dm = data.borrow_mut();
            dm.perturbations.delta_x = d.delta_x;
            dm.perturbations.delta_s = d.delta_s;
            dm.perturbations.delta_c = d.delta_c;
            dm.perturbations.delta_d = d.delta_d;
        }

        // Mark this matrix as "considered" so subsequent `solve_once`
        // re-calls within the same outer `solve()` (iterative refinement
        // / quality retry) take the single-solve path above.
        self.matrix_considered = true;

        expand_bound_multipliers(b, rhs, &mut sol);

        // res = α · sol + β · res
        let frozen_sol = sol.freeze();
        res.add_one_vector(alpha, &frozen_sol, beta);
        true
    }
1086
1087    /// `resid = M · res − rhs` per `ComputeResiduals`. Skips terms
1088    /// whose perturbation is exactly zero.
1089    fn compute_residuals(
1090        &self,
1091        _data: &IpoptDataHandle,
1092        b: &SolveBlocks<'_>,
1093        rhs: &IteratesVector,
1094        res: &IteratesVectorMut,
1095        resid: &mut IteratesVectorMut,
1096    ) {
1097        let d = self.perturb.borrow().current_perturbation();
1098
1099        // x: W·res.x + J_c^T·res.y_c + J_d^T·res.y_d
1100        //    − Px_L·res.z_L + Px_U·res.z_U + δ_x·res.x − rhs.x
1101        b.w.mult_vector(1.0, &*res.x, 0.0, &mut *resid.x);
1102        b.j_c.trans_mult_vector(1.0, &*res.y_c, 1.0, &mut *resid.x);
1103        b.j_d.trans_mult_vector(1.0, &*res.y_d, 1.0, &mut *resid.x);
1104        b.px_l.mult_vector(-1.0, &*res.z_l, 1.0, &mut *resid.x);
1105        b.px_u.mult_vector(1.0, &*res.z_u, 1.0, &mut *resid.x);
1106        // resid.x += δ_x·res.x − rhs.x
1107        resid
1108            .x
1109            .add_two_vectors(d.delta_x, &*res.x, -1.0, &*rhs.x, 1.0);
1110
1111        // s: Pd_U·res.v_U − Pd_L·res.v_L − res.y_d − rhs.s + δ_s·res.s
1112        b.pd_u.mult_vector(1.0, &*res.v_u, 0.0, &mut *resid.s);
1113        b.pd_l.mult_vector(-1.0, &*res.v_l, 1.0, &mut *resid.s);
1114        resid.s.add_two_vectors(-1.0, &*res.y_d, -1.0, &*rhs.s, 1.0);
1115        if d.delta_s != 0.0 {
1116            resid.s.axpy(d.delta_s, &*res.s);
1117        }
1118
1119        // c: J_c·res.x − δ_c·res.y_c − rhs.y_c
1120        b.j_c.mult_vector(1.0, &*res.x, 0.0, &mut *resid.y_c);
1121        resid
1122            .y_c
1123            .add_two_vectors(-d.delta_c, &*res.y_c, -1.0, &*rhs.y_c, 1.0);
1124
1125        // d: J_d·res.x − res.s − rhs.y_d − δ_d·res.y_d
1126        b.j_d.mult_vector(1.0, &*res.x, 0.0, &mut *resid.y_d);
1127        resid
1128            .y_d
1129            .add_two_vectors(-1.0, &*res.s, -1.0, &*rhs.y_d, 1.0);
1130        if d.delta_d != 0.0 {
1131            resid.y_d.axpy(-d.delta_d, &*res.y_d);
1132        }
1133
1134        // zL: res.z_L · slack_x_L + (Px_L^T·res.x) · z_L − rhs.z_L
1135        resid.z_l.copy(&*res.z_l);
1136        resid.z_l.element_wise_multiply(b.slack_x_l);
1137        let mut tmp_zl = b.z_l.make_new();
1138        b.px_l.trans_mult_vector(1.0, &*res.x, 0.0, &mut *tmp_zl);
1139        tmp_zl.element_wise_multiply(b.z_l);
1140        resid
1141            .z_l
1142            .add_two_vectors(1.0, &*tmp_zl, -1.0, &*rhs.z_l, 1.0);
1143
1144        // zU: res.z_U · slack_x_U − (Px_U^T·res.x) · z_U − rhs.z_U
1145        resid.z_u.copy(&*res.z_u);
1146        resid.z_u.element_wise_multiply(b.slack_x_u);
1147        let mut tmp_zu = b.z_u.make_new();
1148        b.px_u.trans_mult_vector(1.0, &*res.x, 0.0, &mut *tmp_zu);
1149        tmp_zu.element_wise_multiply(b.z_u);
1150        resid
1151            .z_u
1152            .add_two_vectors(-1.0, &*tmp_zu, -1.0, &*rhs.z_u, 1.0);
1153
1154        // vL: res.v_L · slack_s_L + (Pd_L^T·res.s) · v_L − rhs.v_L
1155        resid.v_l.copy(&*res.v_l);
1156        resid.v_l.element_wise_multiply(b.slack_s_l);
1157        let mut tmp_vl = b.v_l.make_new();
1158        b.pd_l.trans_mult_vector(1.0, &*res.s, 0.0, &mut *tmp_vl);
1159        tmp_vl.element_wise_multiply(b.v_l);
1160        resid
1161            .v_l
1162            .add_two_vectors(1.0, &*tmp_vl, -1.0, &*rhs.v_l, 1.0);
1163
1164        // vU: res.v_U · slack_s_U − (Pd_U^T·res.s) · v_U − rhs.v_U
1165        resid.v_u.copy(&*res.v_u);
1166        resid.v_u.element_wise_multiply(b.slack_s_u);
1167        let mut tmp_vu = b.v_u.make_new();
1168        b.pd_u.trans_mult_vector(1.0, &*res.s, 0.0, &mut *tmp_vu);
1169        tmp_vu.element_wise_multiply(b.v_u);
1170        resid
1171            .v_u
1172            .add_two_vectors(-1.0, &*tmp_vu, -1.0, &*rhs.v_u, 1.0);
1173    }
1174
1175    /// `nrm_resid / (min(nrm_res, max_cond·nrm_rhs) + nrm_rhs)`, with
1176    /// `max_cond = 1e6`. Mirrors `ComputeResidualRatio`.
1177    fn compute_residual_ratio(
1178        &self,
1179        rhs: &IteratesVector,
1180        res: &IteratesVectorMut,
1181        resid: &IteratesVectorMut,
1182    ) -> Number {
1183        let nrm_rhs = rhs.amax();
1184        let nrm_res = res.amax();
1185        let nrm_resid = resid.amax();
1186        if nrm_rhs + nrm_res == 0.0 {
1187            nrm_resid
1188        } else {
1189            let max_cond = 1e6;
1190            nrm_resid / (nrm_res.min(max_cond * nrm_rhs) + nrm_rhs)
1191        }
1192    }
1193}
1194
1195impl PdSystemSolver for PdFullSpaceSolver {
1196    fn solve_status(&self) -> ESymSolverStatus {
1197        self.last_status.unwrap_or(ESymSolverStatus::FatalError)
1198    }
1199}
1200
/// Bag of borrowed blocks used by both `solve_once` and
/// `compute_residuals` — keeps argument lists tractable.
///
/// All members are plain borrows with the shared lifetime `'a`; the
/// struct owns nothing.
struct SolveBlocks<'a> {
    /// Symmetric `W` block; handed to the augmented system as `w` and
    /// applied to `res.x` in the residual.
    w: &'a dyn SymMatrix,
    /// Constraint block `J_c` (paired with the `y_c` multipliers).
    j_c: &'a dyn Matrix,
    /// Constraint block `J_d` (paired with the `y_d` multipliers).
    j_d: &'a dyn Matrix,
    /// `Px_L`: maps the `z_L` block into `x` space.
    px_l: &'a dyn Matrix,
    /// `Px_U`: maps the `z_U` block into `x` space.
    px_u: &'a dyn Matrix,
    /// `Pd_L`: maps the `v_L` block into `s` space.
    pd_l: &'a dyn Matrix,
    /// `Pd_U`: maps the `v_U` block into `s` space.
    pd_u: &'a dyn Matrix,
    /// Multiplier vector `z_L` used in the bound rows.
    z_l: &'a dyn Vector,
    /// Multiplier vector `z_U` used in the bound rows.
    z_u: &'a dyn Vector,
    /// Multiplier vector `v_L` used in the bound rows.
    v_l: &'a dyn Vector,
    /// Multiplier vector `v_U` used in the bound rows.
    v_u: &'a dyn Vector,
    /// Slack vector paired with `z_L` (the `S_xL` diagonal).
    slack_x_l: &'a dyn Vector,
    /// Slack vector paired with `z_U` (the `S_xU` diagonal).
    slack_x_u: &'a dyn Vector,
    /// Slack vector paired with `v_L` (the `S_sL` diagonal).
    slack_s_l: &'a dyn Vector,
    /// Slack vector paired with `v_U` (the `S_sU` diagonal).
    slack_s_u: &'a dyn Vector,
    /// Diagonal passed to the augmented system as `d_x`.
    sigma_x: &'a dyn Vector,
    /// Diagonal passed to the augmented system as `d_s`.
    sigma_s: &'a dyn Vector,
}
1222
1223/// Helper trait extension on `IteratesVectorMut` for fresh zeroed
1224/// allocations matching the same shape — the shape lives implicitly
1225/// in the existing components' `dim()`.
1226trait FreshZeroed {
1227    fn fresh_zeroed(&self) -> IteratesVectorMut;
1228}
1229
1230impl FreshZeroed for IteratesVectorMut {
1231    fn fresh_zeroed(&self) -> IteratesVectorMut {
1232        IteratesVectorMut {
1233            x: self.x.make_new(),
1234            s: self.s.make_new(),
1235            y_c: self.y_c.make_new(),
1236            y_d: self.y_d.make_new(),
1237            z_l: self.z_l.make_new(),
1238            z_u: self.z_u.make_new(),
1239            v_l: self.v_l.make_new(),
1240            v_u: self.v_u.make_new(),
1241        }
1242    }
1243}
1244
1245/// Snapshot a mutable iterate into a frozen, shareable copy without
1246/// consuming it. Used to remember `res_in` when β ≠ 0.
1247fn snapshot_mut(m: &IteratesVectorMut) -> IteratesVector {
1248    let mut out = m.fresh_zeroed();
1249    out.x.copy(&*m.x);
1250    out.s.copy(&*m.s);
1251    out.y_c.copy(&*m.y_c);
1252    out.y_d.copy(&*m.y_d);
1253    out.z_l.copy(&*m.z_l);
1254    out.z_u.copy(&*m.z_u);
1255    out.v_l.copy(&*m.v_l);
1256    out.v_u.copy(&*m.v_u);
1257    out.freeze()
1258}
1259
// Note on `thaw` (defined further down in this file): it converts a
// frozen `IteratesVector` back to a mutable owned form. It allocates
// fresh storage and copies; the iterative-refinement loop
// re-freezes/thaws once per iteration, so a single per-component
// copy is acceptable.
1264/// Expand the four bound-multiplier blocks of `sol` from the just-
1265/// computed primal-step blocks (`sol.x`, `sol.s`):
1266///
1267/// ```text
1268/// sol.z_L = S_xL⁻¹ · (rhs.z_L − z_L · (Px_L^T · sol.x))
1269/// sol.z_U = S_xU⁻¹ · (rhs.z_U + z_U · (Px_U^T · sol.x))
1270/// sol.v_L = S_sL⁻¹ · (rhs.v_L − v_L · (Pd_L^T · sol.s))
1271/// sol.v_U = S_sU⁻¹ · (rhs.v_U + v_U · (Pd_U^T · sol.s))
1272/// ```
1273///
1274/// Encoded via `SinvBlrmZMTdBr` with `α = ±1`. Mirrors the bound-
1275/// multiplier expansion at the bottom of upstream's
1276/// `IpPDFullSpaceSolver::SolveOnce`.
1277fn expand_bound_multipliers(
1278    b: &SolveBlocks<'_>,
1279    rhs: &IteratesVector,
1280    sol: &mut IteratesVectorMut,
1281) {
1282    b.px_l
1283        .sinv_blrm_zmt_dbr(-1.0, b.slack_x_l, &*rhs.z_l, b.z_l, &*sol.x, &mut *sol.z_l);
1284    b.px_u
1285        .sinv_blrm_zmt_dbr(1.0, b.slack_x_u, &*rhs.z_u, b.z_u, &*sol.x, &mut *sol.z_u);
1286    b.pd_l
1287        .sinv_blrm_zmt_dbr(-1.0, b.slack_s_l, &*rhs.v_l, b.v_l, &*sol.s, &mut *sol.v_l);
1288    b.pd_u
1289        .sinv_blrm_zmt_dbr(1.0, b.slack_s_u, &*rhs.v_u, b.v_u, &*sol.s, &mut *sol.v_u);
1290}
1291
1292/// Copy a `DenseVector`'s materialized values into `dst`. Used by
1293/// `solve_many_cached` to pack the aug-system RHS into a column of the
1294/// flat `aug_packed` buffer.
1295fn copy_vector_to_slice(src: &dyn Vector, dst: &mut [Number]) {
1296    if dst.is_empty() {
1297        return;
1298    }
1299    let dv = src
1300        .as_any()
1301        .downcast_ref::<pounce_linalg::dense_vector::DenseVector>()
1302        .expect("solve_many_cached requires DenseVector blocks");
1303    if dv.is_homogeneous() {
1304        let v = dv.scalar();
1305        dst.iter_mut().for_each(|x| *x = v);
1306    } else {
1307        dst.copy_from_slice(dv.values());
1308    }
1309}
1310
1311/// Inverse of [`copy_vector_to_slice`]: write `src` into a
1312/// `DenseVector` in place.
1313fn set_vector_from_slice(dst: &mut dyn Vector, src: &[Number]) {
1314    if src.is_empty() {
1315        return;
1316    }
1317    let dv = dst
1318        .as_any_mut()
1319        .downcast_mut::<pounce_linalg::dense_vector::DenseVector>()
1320        .expect("solve_many_cached requires DenseVector blocks");
1321    dv.set_values(src);
1322}
1323
1324/// Downcast a `dyn Vector` block to its concrete `DenseVector` slice.
1325/// Returns `None` if the block is not a `DenseVector`, or if the block
1326/// is non-empty but stored as a homogeneous scalar (the
1327/// `solve_many_cached_flat` fast path needs a real slice for its
1328/// inline scatter loops; the closure-based fallback handles
1329/// homogeneous-on-non-empty correctly via `add_m_sinv_z` / `sinv_blrm_zmt_dbr`).
1330fn dense_slice_or_none(v: &dyn Vector, expected_dim: usize) -> Option<&[Number]> {
1331    if expected_dim == 0 {
1332        // An empty block doesn't need a slice — the scatter loops below
1333        // simply don't iterate when exp_pos is empty. Return an empty
1334        // slice so the caller can pass it through unconditionally.
1335        return Some(&[]);
1336    }
1337    let dv = v.as_any().downcast_ref::<DenseVector>()?;
1338    if dv.is_homogeneous() {
1339        return None;
1340    }
1341    Some(dv.values())
1342}
1343
1344/// Downcast a `dyn Matrix` to its concrete `ExpansionMatrix`'s
1345/// expanded-position index slice. Returns `None` if the matrix is not
1346/// an `ExpansionMatrix`.
1347fn exp_pos_or_none(m: &dyn Matrix) -> Option<&[Index]> {
1348    let em = m.as_any().downcast_ref::<ExpansionMatrix>()?;
1349    Some(em.expanded_pos_indices())
1350}
1351
1352/// Phase-1 inner kernel: `out[exp_pos[i]] += alpha · src[i] / denom[i]`.
1353/// Hot loop in `solve_many_cached_flat`. Specialised on `alpha = ±1`
1354/// to skip the multiply.
1355#[inline]
1356fn scatter_add_div(
1357    out: &mut [Number],
1358    exp_pos: &[Index],
1359    src: &[Number],
1360    denom: &[Number],
1361    alpha: Number,
1362) {
1363    if exp_pos.is_empty() {
1364        return;
1365    }
1366    debug_assert_eq!(src.len(), exp_pos.len());
1367    debug_assert_eq!(denom.len(), exp_pos.len());
1368    if alpha == 1.0 {
1369        for i in 0..exp_pos.len() {
1370            out[exp_pos[i] as usize] += src[i] / denom[i];
1371        }
1372    } else if alpha == -1.0 {
1373        for i in 0..exp_pos.len() {
1374            out[exp_pos[i] as usize] -= src[i] / denom[i];
1375        }
1376    } else {
1377        for i in 0..exp_pos.len() {
1378            out[exp_pos[i] as usize] += alpha * src[i] / denom[i];
1379        }
1380    }
1381}
1382
1383/// Phase-3 inner kernel: bound-multiplier expansion,
1384/// `out[i] = (r[i] + alpha · z[i] · sol[exp_pos[i]]) / s[i]`.
1385/// Mirrors `ExpansionMatrix::sinv_blrm_zmt_dbr_impl` (the non-
1386/// homogeneous specialisation) inlined against raw slices.
1387#[inline]
1388#[allow(clippy::too_many_arguments)]
1389fn expand_bound_mult(
1390    out: &mut [Number],
1391    r: &[Number],
1392    z: &[Number],
1393    sol: &[Number],
1394    exp_pos: &[Index],
1395    s: &[Number],
1396    alpha: Number,
1397) {
1398    if exp_pos.is_empty() {
1399        return;
1400    }
1401    debug_assert_eq!(out.len(), exp_pos.len());
1402    debug_assert_eq!(r.len(), exp_pos.len());
1403    debug_assert_eq!(z.len(), exp_pos.len());
1404    debug_assert_eq!(s.len(), exp_pos.len());
1405    if alpha == 1.0 {
1406        for i in 0..exp_pos.len() {
1407            out[i] = (r[i] + z[i] * sol[exp_pos[i] as usize]) / s[i];
1408        }
1409    } else if alpha == -1.0 {
1410        for i in 0..exp_pos.len() {
1411            out[i] = (r[i] - z[i] * sol[exp_pos[i] as usize]) / s[i];
1412        }
1413    } else {
1414        for i in 0..exp_pos.len() {
1415            out[i] = (r[i] + alpha * z[i] * sol[exp_pos[i] as usize]) / s[i];
1416        }
1417    }
1418}
1419
1420fn thaw(iv: IteratesVector) -> IteratesVectorMut {
1421    fn one(v: Rc<dyn Vector>) -> Box<dyn Vector> {
1422        let mut b = v.make_new();
1423        b.copy(&*v);
1424        b
1425    }
1426    IteratesVectorMut {
1427        x: one(iv.x),
1428        s: one(iv.s),
1429        y_c: one(iv.y_c),
1430        y_d: one(iv.y_d),
1431        z_l: one(iv.z_l),
1432        z_u: one(iv.z_u),
1433        v_l: one(iv.v_l),
1434        v_u: one(iv.v_u),
1435    }
1436}
1437
1438/// Internal placeholder used only inside [`PdFullSpaceSolver::wrap_aug_solver`]
1439/// to satisfy `std::mem::replace`'s requirement for a value of the same
1440/// type while the real boxed solver is being moved through the wrapper
1441/// closure. None of the trait methods are ever invoked.
1442struct NoopAugSolver;
1443
1444impl AugSystemSolver for NoopAugSolver {
1445    fn provides_inertia(&self) -> bool {
1446        unreachable!("NoopAugSolver is a transient placeholder")
1447    }
1448    fn number_of_neg_evals(&self) -> Index {
1449        unreachable!("NoopAugSolver is a transient placeholder")
1450    }
1451    fn increase_quality(&mut self) -> bool {
1452        unreachable!("NoopAugSolver is a transient placeholder")
1453    }
1454    fn last_solve_status(&self) -> ESymSolverStatus {
1455        unreachable!("NoopAugSolver is a transient placeholder")
1456    }
1457    fn solve(
1458        &mut self,
1459        _coeffs: &AugSysCoeffs<'_>,
1460        _rhs: &AugSysRhs<'_>,
1461        _sol: &mut AugSysSol<'_>,
1462        _check_neg_evals: bool,
1463        _num_neg_evals: Index,
1464    ) -> ESymSolverStatus {
1465        unreachable!("NoopAugSolver is a transient placeholder")
1466    }
1467}