Skip to main content

pounce_sensitivity/
step_calc.rs

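//! `SensStepCalc` trait — orchestrates the sensitivity step computation.
//!
//! Given a converged KKT iterate and a parameter perturbation
//! `Δp`, sIPOPT's step calc produces the first-order primal/dual
//! sensitivity step `(Δx, Δλ, Δz)`. The math factors through two
//! linear systems:
//!
//! 1. **Schur solve**: `S · Δu = −Bᵀ Δp`, with `S` factored by
//!    [`crate::schur_driver::SchurDriver`].
//! 2. **Augmented backsolve**: `K · Δx = −A · Δu`, via the
//!    [`crate::SensBacksolver`].
//!
//! [`StdStepCalc`] itself applies neither minus sign. It solves
//! `S · Δu = rhs_u` for a caller-supplied `rhs_u` and then backsolves
//! `K · Δx = A · Δu`. A caller following the formulas above must therefore
//! fold both signs in itself, e.g. pass `rhs_u = −Bᵀ Δp` and negate the
//! returned `Δx`.
//!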
//! Reference: Pirnay, López-Negrete & Biegler 2012, §3 (DOI:
//! [10.1007/s12532-012-0043-2](https://doi.org/10.1007/s12532-012-0043-2)).
//! Upstream implementation:
//! [`SensStdStepCalc.{hpp,cpp}`](../../../ref/Ipopt/contrib/sIPOPT/src/SensStdStepCalc.cpp).
//!
//! # Phase B.1 scope
//!
//! This file ships the trait plus a minimal `StdStepCalc` that exercises
//! the two-step pipeline. It uses the same `SensBacksolver` instance for
//! both the Schur build and the inner augmented backsolve. The
//! algorithm-side wiring that produces the actual `Δp` source vector
//! from a TNLP's parameter-perturbation slot lands in Phase B.2.
25
26use crate::backsolver::SensBacksolver;
27use crate::p_calculator::PCalculator;
28use crate::schur_driver::SchurDriver;
29use pounce_common::types::Number;
30
31/// Compute a sensitivity step `Δu = S⁻¹ · rhs_u` (Schur-space) and
32/// `Δx_full = K⁻¹ · A · Δu` (backsolved KKT-space), where the
33/// `rhs_u` vector encodes the parameter perturbation.
34///
35/// Mirrors upstream `SensStepCalc` (abstract,
36/// [`SensStepCalc.hpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensStepCalc.hpp))
37/// + `SensStdStepCalc` (concrete, the only implementation upstream
38/// ships).
39pub trait SensStepCalc {
40    /// Run the two-step sensitivity computation. Outputs:
41    /// - `du`: length `n_b`, the Schur-space step.
42    /// - `dx_full`: length `n_full` (the backsolver's dimension),
43    ///   the full primal/dual step `K⁻¹ A · du`. Implementations
44    ///   may apply the upstream sign convention internally (the
45    ///   exact sign depends on which side of the augmented system
46    ///   the perturbation enters; see the upstream reference for
47    ///   the parametric-flavor convention).
48    ///
49    /// Returns `false` if the Schur driver hasn't been built /
50    /// factored, the backsolver fails, or the buffers are
51    /// mis-sized.
52    fn compute_step(&self, rhs_u: &[Number], du: &mut [Number], dx_full: &mut [Number]) -> bool;
53}
54
/// Reference implementation that strings together
/// [`SchurDriver::schur_solve`] and a final
/// [`SensBacksolver::solve`] using the `A` data the driver was
/// built with.
///
/// Mirrors upstream
/// [`SensStdStepCalc.cpp:23-282`](../../../ref/Ipopt/contrib/sIPOPT/src/SensStdStepCalc.cpp).
///
/// The Schur driver and its `PCalculator` are both borrowed for the
/// lifetime `'d`. This matches upstream's "build once, step many"
/// pattern. The `SchurDriver + WithBacksolver` / `PCalculator` bounds
/// live on the [`SensStepCalc`] impl that actually needs them, not on
/// the struct definition.
#[derive(Debug)]
pub struct StdStepCalc<'d, D, P> {
    /// Schur driver providing the factored `S` solve and, via
    /// `WithBacksolver`, the final `K` solve.
    driver: &'d D,
    /// Borrow of the same `PCalculator` that built the driver. It
    /// exposes the `A` data for the final `A · du` scatter, so the
    /// caller does not have to thread it through separately.
    ///
    /// NOTE: nothing checks that this really is the driver's own
    /// pcalc. Passing a different one silently pairs mismatched data.
    pcalc: &'d P,
}

impl<'d, D, P> StdStepCalc<'d, D, P> {
    /// Construct from references to the driver and the matching pcalc.
    ///
    /// The driver is expected to be in the post-factor state already
    /// (i.e. after its Schur build + factor). This is not checked here.
    /// An unfactored driver surfaces later as a `false` return from
    /// [`SensStepCalc::compute_step`].
    pub fn new(driver: &'d D, pcalc: &'d P) -> Self {
        Self { driver, pcalc }
    }
}
81
82/// Bridge trait — exposes a `SensBacksolver`-shaped solve through
83/// whatever the driver wraps. Implementations of `SchurDriver` that
84/// want to be consumed by `StdStepCalc` opt in by also implementing
85/// `WithBacksolver`. This keeps `SchurDriver`'s own surface
86/// minimal — most drivers don't need to expose the inner backsolver.
87pub trait WithBacksolver {
88    /// Apply `K⁻¹ · rhs` and write the result into `out`. Returns
89    /// `false` if the inner backsolver fails or buffers don't match.
90    fn k_solve(&self, rhs: &[Number], out: &mut [Number]) -> bool;
91}
92
93impl<'d, D, P> SensStepCalc for StdStepCalc<'d, D, P>
94where
95    D: SchurDriver + WithBacksolver,
96    P: PCalculator,
97{
98    fn compute_step(&self, rhs_u: &[Number], du: &mut [Number], dx_full: &mut [Number]) -> bool {
99        // 1. Schur step: solve S · du = rhs_u.
100        if !self.driver.schur_solve(rhs_u, du) {
101            return false;
102        }
103        // 2. Construct the KKT-side rhs: A · du. The trans_multiply
104        //    method on a SchurData computes Bᵀ u, so calling
105        //    A.trans_multiply(du, scratch) gives Aᵀ Bᵀ = … wait, A
106        //    is the row-space matrix, so multiply by du to get the
107        //    full-state rhs. The pounce convention: each row of A
108        //    selects a single full-state component, so
109        //    `A.trans_multiply(du, rhs)` scatters `du[i]` into
110        //    `rhs[A_idx[i]]`.
111        let a = self.pcalc.data_a();
112        let n_full = dx_full.len();
113        let mut rhs_full = vec![0.0; n_full];
114        if let Err(_) = a.trans_multiply(du, &mut rhs_full) {
115            return false;
116        }
117        // 3. Backsolve K · dx_full = rhs_full.
118        self.driver.k_solve(&rhs_full, dx_full)
119    }
120}
121
122/// Convenience impl: a `DenseGenSchurDriver` parametrized over an
123/// `IndexPCalculator<B>` can hand off its inner backsolver via
124/// this bridge.
125impl<B> WithBacksolver
126    for crate::schur_driver::DenseGenSchurDriver<crate::p_calculator::IndexPCalculator<B>, B>
127where
128    B: SensBacksolver,
129{
130    fn k_solve(&self, rhs: &[Number], out: &mut [Number]) -> bool {
131        // The IndexPCalculator owns the backsolver; reach through.
132        self.pcalc().backsolver().solve(rhs, out)
133    }
134}
135
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backsolver::DenseLuBacksolver;
    use crate::p_calculator::IndexPCalculator;
    use crate::schur_data::IndexSchurData;
    use crate::schur_driver::DenseGenSchurDriver;

    /// Drives `StdStepCalc` end to end on the shared fixture: the 3×3
    /// SPD tridiagonal `K` and a 2-row `B = A` selecting full-state
    /// components 0 and 2.
    ///
    /// Hand-derived expectations:
    ///
    /// - `K⁻¹ = 1/4 · [[3, 2, 1], [2, 4, 2], [1, 2, 3]]`.
    /// - `S = -K⁻¹` restricted to rows/cols {0, 2}
    ///   `= -[[3/4, 1/4], [1/4, 3/4]]`.
    /// - `S · du = (1, 0)` gives `du = (-3/2, 1/2)`. The schur_driver test
    ///   checks the same value.
    /// - Lifting through `A` (`e_0 · du[0] + e_2 · du[1]`) gives
    ///   `rhs_full = (-3/2, 0, 1/2)`.
    /// - `K⁻¹ · rhs_full = 1/4 · (3·(-3/2) + 1·(1/2),
    ///                           2·(-3/2) + 2·(1/2),
    ///                           1·(-3/2) + 3·(1/2))
    ///                   = 1/4 · (-4, -2, 0)
    ///                   = (-1, -1/2, 0)`.
    #[test]
    fn std_step_calc_runs_two_step_pipeline() {
        // Row-major [-1, 2, -1] tridiagonal stencil.
        #[rustfmt::skip]
        let k_dense = vec![
             2.0, -1.0,  0.0,
            -1.0,  2.0, -1.0,
             0.0, -1.0,  2.0,
        ];
        let lu = DenseLuBacksolver::from_dense(3, &k_dense).unwrap();
        // `A` and `B` both pick out full-state components 0 and 2.
        let a_rows = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        let index_pcalc = IndexPCalculator::new(lu, a_rows);
        let mut driver = DenseGenSchurDriver::<_, DenseLuBacksolver>::new(index_pcalc);
        let b_rows = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).unwrap();
        assert!(driver.schur_build_and_factor(&b_rows));

        let calc = StdStepCalc::new(&driver, driver.pcalc());
        let rhs_u = [1.0, 0.0];
        let (mut du, mut dx) = ([0.0; 2], [0.0; 3]);
        assert!(calc.compute_step(&rhs_u, &mut du, &mut dx));

        // Schur-space step.
        let want_du = [-1.5, 0.5];
        assert!((du[0] - want_du[0]).abs() < 1e-10, "du[0] = {}", du[0]);
        assert!((du[1] - want_du[1]).abs() < 1e-10, "du[1] = {}", du[1]);

        // Full-space step.
        let want_dx = [-1.0, -0.5, 0.0];
        assert!((dx[0] - want_dx[0]).abs() < 1e-10, "dx[0] = {}", dx[0]);
        assert!((dx[1] - want_dx[1]).abs() < 1e-10, "dx[1] = {}", dx[1]);
        assert!((dx[2] - want_dx[2]).abs() < 1e-10, "dx[2] = {}", dx[2]);
    }
}