Skip to main content

pounce_sensitivity/
solver.rs

1//! `Solver` — value-typed session API that holds an `IpoptApplication`,
2//! its TNLP, and the converged KKT factor between calls.
3//!
4//! This is Phase 3a of the factor-reuse work tracked in
5//! [pounce#16](https://github.com/jkitchin/pounce/issues/16). It is
6//! the public surface for callers who want to:
7//!
8//! 1. Run a normal IPM solve, then
9//! 2. Issue many cheap operations against the converged factor
10//!    (`kkt_solve`, `parametric_step`) without going through the
11//!    [`set_on_converged`] callback shape that [`crate::SensSolve`]
12//!    requires.
13//!
14//! [`set_on_converged`]: pounce_algorithm::IpoptApplication::set_on_converged
15//!
16//! # Usage
17//!
18//! ```ignore
19//! use pounce_sensitivity::Solver;
20//! use std::cell::RefCell;
21//! use std::rc::Rc;
22//!
23//! let app = make_configured_app();
24//! let tnlp: Rc<RefCell<dyn TNLP>> = Rc::new(RefCell::new(MyTnlp));
25//! let mut solver = Solver::new(app, tnlp);
26//!
27//! let status = solver.solve();
28//! assert!(solver.converged().is_some());
29//!
30//! // Issue any number of back-solves against the same factor:
31//! let dim = solver.kkt_dim().unwrap();
32//! let mut lhs = vec![0.0; dim];
33//! let rhs = vec![1.0; dim];
34//! solver.kkt_solve(&rhs, &mut lhs).unwrap();
35//!
36//! // Parametric step with respect to a set of pinned equality
37//! // constraints (same interpretation as [`crate::SensSolve`]):
38//! let dx = solver.parametric_step(&[2, 3], &[-0.5, 0.0]).unwrap();
39//! ```
40//!
41//! # Scope of Phase 3a
42//!
//! - **In**: `solve()`, `converged()`, `kkt_solve()`, `kkt_solve_many()`,
//!   `parametric_step()`, `compute_reduced_hessian()`,
//!   `block_dims()` / `kkt_dim()`.
//! - **Deferred to Phase 3b**: `resolve()` (warm-start that reuses the
//!   linear backend pool) and the `parametric_mpc` /
//!   `sensitivity_session` example binaries.
49
50use std::cell::{Ref, RefCell};
51use std::rc::Rc;
52
53use pounce_algorithm::application::IpoptApplication;
54use pounce_common::types::{Index, Number};
55use pounce_nlp::return_codes::ApplicationReturnStatus;
56use pounce_nlp::TNLP;
57
58use crate::backsolver::SensBacksolver;
59use crate::schur_data::IndexSchurData;
60use crate::sens_app::{SensApplication, SensOptions};
61use crate::PdSensBacksolver;
62
/// Errors returned by post-convergence operations on [`Solver`].
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so values
/// can be propagated with `?` into `Box<dyn Error>` and printed for
/// end users without going through `Debug`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SolverError {
    /// The solver has not yet converged, or the last solve failed
    /// before producing a usable KKT factor.
    NotConverged,
    /// An input slice's length did not match the KKT dimension or the
    /// parameter count.
    BadShape {
        /// Human description of the mismatched buffer.
        what: &'static str,
        /// Length the caller passed.
        got: usize,
        /// Length expected.
        expected: usize,
    },
    /// The underlying back-solve failed (singular factor, numerical
    /// breakdown).
    BacksolveFailed,
    /// The underlying [`SensApplication`] step failed (e.g. row mapping
    /// invalid for the current problem). The payload is a free-form
    /// diagnostic message.
    SensComputationFailed(String),
}

impl std::fmt::Display for SolverError {
    /// Render a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotConverged => f.write_str("solver holds no converged KKT factor"),
            Self::BadShape {
                what,
                got,
                expected,
            } => write!(
                f,
                "bad shape for `{what}`: got length {got}, expected {expected}"
            ),
            Self::BacksolveFailed => f.write_str("KKT back-solve failed"),
            Self::SensComputationFailed(msg) => {
                write!(f, "sensitivity computation failed: {msg}")
            }
        }
    }
}

impl std::error::Error for SolverError {}
86
87/// State captured at convergence: the user-visible iterate plus the
88/// `PdSensBacksolver` that wraps the converged KKT factor.
89///
90/// Read this via [`Solver::converged`].
91pub struct ConvergedState {
92    /// IPM return status of the most recent solve.
93    pub status: ApplicationReturnStatus,
94    /// Final primal iterate `x*` (length `n_x`).
95    pub x: Vec<Number>,
96    /// Final objective value `f(x*)`.
97    pub obj_val: Number,
98    /// Converged KKT-factor wrapper. Owns `Rc` handles to the
99    /// `PdFullSpaceSolver`, the IpoptData / Cq, and the NLP, so it
100    /// outlives the IPM call frame.
101    backsolver: PdSensBacksolver,
102}
103
104impl ConvergedState {
105    /// Block dimensions of the compound KKT vector in
106    /// `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order.
107    pub fn block_dims(&self) -> [usize; 8] {
108        self.backsolver.block_dims()
109    }
110
111    /// Total dimension of the compound KKT vector (sum of `block_dims`).
112    pub fn kkt_dim(&self) -> usize {
113        self.backsolver.dim()
114    }
115}
116
117/// Session-style solver: holds an [`IpoptApplication`], its TNLP, and
118/// the converged factor between calls.
119pub struct Solver {
120    app: IpoptApplication,
121    tnlp: Rc<RefCell<dyn TNLP>>,
122    /// Side channel populated by the `on_converged` callback installed
123    /// in [`Self::solve`]. The `RefCell<Option<…>>` shape mirrors the
124    /// pattern in [`crate::convenience`] (the callback closure needs
125    /// shared mutable access; the `Option` is `None` before the first
126    /// solve and gets overwritten on each call).
127    state: Rc<RefCell<Option<ConvergedState>>>,
128}
129
130impl Solver {
131    /// Build a new session. The `app` should already have its options
132    /// configured and `initialize()` called.
133    pub fn new(app: IpoptApplication, tnlp: Rc<RefCell<dyn TNLP>>) -> Self {
134        Self {
135            app,
136            tnlp,
137            state: Rc::new(RefCell::new(None)),
138        }
139    }
140
141    /// Borrow the underlying `IpoptApplication` (e.g. to read its
142    /// options table after a solve). Mutation between `solve` calls is
143    /// supported via [`Self::app_mut`].
144    pub fn app(&self) -> &IpoptApplication {
145        &self.app
146    }
147
148    /// Mutable borrow of the underlying `IpoptApplication`. Useful for
149    /// reconfiguring options before a follow-up `solve()`. Note that
150    /// changing options that affect the KKT linear system between
151    /// calls will invalidate the cached factor; the next `solve()`
152    /// rebuilds it.
153    pub fn app_mut(&mut self) -> &mut IpoptApplication {
154        &mut self.app
155    }
156
157    /// Run the IPM to convergence. On a successful solve the
158    /// [`ConvergedState`] (including the KKT backsolver) is stashed
159    /// inside the `Solver` and accessible via [`Self::converged`].
160    ///
161    /// Each call to `solve()` overwrites the previous converged
162    /// state; the previously held factor is dropped.
163    pub fn solve(&mut self) -> ApplicationReturnStatus {
164        // Clear any previous state so a failed re-solve doesn't leave
165        // a stale factor visible.
166        self.state.borrow_mut().take();
167
168        let state_cb = Rc::clone(&self.state);
169        self.app
170            .set_on_converged(Box::new(move |data, cq, nlp, pd| {
171                let curr = match data.borrow().curr.clone() {
172                    Some(c) => c,
173                    None => return,
174                };
175                let backsolver = match PdSensBacksolver::new(data, cq, nlp, Rc::clone(&pd)) {
176                    Ok(b) => b,
177                    Err(_) => return,
178                };
179                let x = dense_to_vec(&*curr.x);
180                let obj_val = cq.borrow_mut().curr_f();
181                // Status is overwritten with the real value after
182                // optimize_tnlp returns.
183                *state_cb.borrow_mut() = Some(ConvergedState {
184                    status: ApplicationReturnStatus::InternalError,
185                    x,
186                    obj_val,
187                    backsolver,
188                });
189            }));
190
191        let status = self.app.optimize_tnlp(Rc::clone(&self.tnlp));
192        if let Some(s) = self.state.borrow_mut().as_mut() {
193            s.status = status;
194        }
195        status
196    }
197
198    /// Borrow the converged state, if a successful solve has been
199    /// run. Returns `None` if no solve has run or if the most recent
200    /// solve failed before reaching convergence.
201    pub fn converged(&self) -> Option<Ref<'_, ConvergedState>> {
202        let r = self.state.borrow();
203        r.as_ref()?;
204        Some(Ref::map(r, |o| {
205            o.as_ref()
206                .unwrap_or_else(|| unreachable!("checked is_some above"))
207        }))
208    }
209
210    /// Total dimension of the compound KKT vector (sum of
211    /// `block_dims`). Returns `None` if no converged factor is held.
212    pub fn kkt_dim(&self) -> Option<usize> {
213        self.converged().map(|c| c.kkt_dim())
214    }
215
216    /// Block dimensions of the compound KKT vector in
217    /// `(x, s, y_c, y_d, z_l, z_u, v_l, v_u)` order. Returns `None` if
218    /// no converged factor is held.
219    pub fn block_dims(&self) -> Option<[usize; 8]> {
220        self.converged().map(|c| c.block_dims())
221    }
222
223    /// Solve `K · lhs = rhs` against the converged KKT factor. Both
224    /// slices must have length `kkt_dim()`; the layout is the flat
225    /// `x || s || y_c || y_d || z_l || z_u || v_l || v_u` packing.
226    pub fn kkt_solve(&self, rhs: &[Number], lhs: &mut [Number]) -> Result<(), SolverError> {
227        let state = self.state.borrow();
228        let state = state.as_ref().ok_or(SolverError::NotConverged)?;
229        let total = state.backsolver.dim();
230        if rhs.len() != total {
231            return Err(SolverError::BadShape {
232                what: "rhs",
233                got: rhs.len(),
234                expected: total,
235            });
236        }
237        if lhs.len() != total {
238            return Err(SolverError::BadShape {
239                what: "lhs",
240                got: lhs.len(),
241                expected: total,
242            });
243        }
244        if state.backsolver.solve(rhs, lhs) {
245            Ok(())
246        } else {
247            Err(SolverError::BacksolveFailed)
248        }
249    }
250
251    /// Batched-RHS back-solve. `rhs_flat` and `lhs_flat` are row-major
252    /// `(n_rhs, kkt_dim)` buffers; each row is solved against the
253    /// same converged factor. Equivalent in result to looping
254    /// [`Self::kkt_solve`] but reuses one `IteratesVector` for the
255    /// RHS and one for the result across all `n_rhs` calls — see
256    /// [`crate::algorithm_backsolver::PdSensBacksolver::solve_many`].
257    pub fn kkt_solve_many(
258        &self,
259        rhs_flat: &[Number],
260        lhs_flat: &mut [Number],
261        n_rhs: usize,
262    ) -> Result<(), SolverError> {
263        let state = self.state.borrow();
264        let state = state.as_ref().ok_or(SolverError::NotConverged)?;
265        let total = state.backsolver.dim();
266        let expected = n_rhs * total;
267        if rhs_flat.len() != expected {
268            return Err(SolverError::BadShape {
269                what: "rhs",
270                got: rhs_flat.len(),
271                expected,
272            });
273        }
274        if lhs_flat.len() != expected {
275            return Err(SolverError::BadShape {
276                what: "lhs",
277                got: lhs_flat.len(),
278                expected,
279            });
280        }
281        if state.backsolver.solve_many(rhs_flat, lhs_flat, n_rhs) {
282            Ok(())
283        } else {
284            Err(SolverError::BacksolveFailed)
285        }
286    }
287
288    /// First-order parametric step `Δx ≈ ∂x*/∂p · Δp` for a set of
289    /// pinned equality constraints. `pin_constraint_indices` are
290    /// 0-based indices into the user's `g(x)`; `deltas` is the
291    /// perturbation `Δp` (same length).
292    ///
293    /// Returns the `n_x`-long primal step. For the full KKT-space
294    /// step, use [`Self::kkt_solve`] directly.
295    pub fn parametric_step(
296        &self,
297        pin_constraint_indices: &[Index],
298        deltas: &[Number],
299    ) -> Result<Vec<Number>, SolverError> {
300        if pin_constraint_indices.len() != deltas.len() {
301            return Err(SolverError::BadShape {
302                what: "deltas",
303                got: deltas.len(),
304                expected: pin_constraint_indices.len(),
305            });
306        }
307        let state = self.state.borrow();
308        let state = state.as_ref().ok_or(SolverError::NotConverged)?;
309
310        // y_c rows live right after the (x, s) primal block in the
311        // compound-vector layout (matches `convenience.rs`).
312        let dims = state.backsolver.block_dims();
313        let n_x = dims[0];
314        let n_s = dims[1];
315        let y_c_offset = (n_x + n_s) as Index;
316        let param_rows: Vec<Index> = pin_constraint_indices
317            .iter()
318            .map(|&i| y_c_offset + i)
319            .collect();
320        let signs = vec![1; pin_constraint_indices.len()];
321        let a_data = IndexSchurData::from_parts(param_rows, signs)
322            .map_err(|e| SolverError::SensComputationFailed(format!("{e:?}")))?;
323
324        let opts = SensOptions {
325            run_sens: true,
326            ..SensOptions::default()
327        };
328        let sens_app = SensApplication::new(a_data, state.backsolver.clone(), opts);
329        let n_full = state.backsolver.dim();
330        let mut dx_full = vec![0.0; n_full];
331        if !sens_app.parametric_step(deltas, &mut dx_full) {
332            return Err(SolverError::SensComputationFailed(
333                "SensApplication::parametric_step failed".into(),
334            ));
335        }
336        dx_full.truncate(n_x);
337        Ok(dx_full)
338    }
339
340    /// Reduced Hessian `H_R = obj_scal · B K⁻¹ Bᵀ` over the pinned
341    /// equality-constraint rows, where `B` selects the
342    /// `pin_constraint_indices` rows of the y_c block. Returns the
343    /// `n²`-long column-major dense matrix (`n = pin_constraint_indices.len()`).
344    ///
345    /// Equivalent to [`crate::SensSolve::with_reduced_hessian`] but
346    /// usable post-hoc on a held `Solver`.
347    pub fn compute_reduced_hessian(
348        &self,
349        pin_constraint_indices: &[Index],
350        obj_scal: Number,
351    ) -> Result<Vec<Number>, SolverError> {
352        let state = self.state.borrow();
353        let state = state.as_ref().ok_or(SolverError::NotConverged)?;
354        let n = pin_constraint_indices.len();
355        let dims = state.backsolver.block_dims();
356        let y_c_offset = (dims[0] + dims[1]) as Index;
357        let param_rows: Vec<Index> = pin_constraint_indices
358            .iter()
359            .map(|&i| y_c_offset + i)
360            .collect();
361        let signs = vec![1; n];
362        let a_data = IndexSchurData::from_parts(param_rows, signs)
363            .map_err(|e| SolverError::SensComputationFailed(format!("{e:?}")))?;
364        let opts = SensOptions {
365            compute_red_hessian: true,
366            obj_scal,
367            ..SensOptions::default()
368        };
369        let mut sens_app = SensApplication::new(a_data, state.backsolver.clone(), opts);
370        let mut hr = vec![0.0; n * n];
371        if !sens_app.compute_reduced_hessian(&mut hr) {
372            return Err(SolverError::SensComputationFailed(
373                "SensApplication::compute_reduced_hessian failed".into(),
374            ));
375        }
376        Ok(hr)
377    }
378}
379
380fn dense_to_vec(v: &dyn pounce_linalg::Vector) -> Vec<Number> {
381    match v
382        .as_any()
383        .downcast_ref::<pounce_linalg::dense_vector::DenseVector>()
384    {
385        Some(d) => d.values().to_vec(),
386        None => vec![0.0; v.dim() as usize],
387    }
388}