// pounce_sensitivity/diff_handoff.rs

//! The `solve → DiffHandoff` contract — the solver-agnostic bundle that
//! every differentiable solve hands to its backward pass.
//!
//! Design: `dev-notes/diff-handoff-contract.md`. The motivation is that
//! POUNCE differentiates solves across several frontends (JAX / PyTorch,
//! NLP / QP), and each was re-deriving the same *active-set* facts —
//! "a bound is active when its multiplier exceeds a tolerance", "an
//! equality row is always active", "active-bound / fixed (e.g. integer)
//! variables are pinned, `dx/dp = 0`". This struct computes those facts
//! **once**, in the producer, so every consumer reads them instead of
//! recomputing `|mult| > tol` under its own tolerance.
//!
//! This module is intentionally small and dependency-light: it is plain
//! data plus the one active-set derivation. It does *not* own the KKT
//! factor's linear algebra — that stays in [`crate::solver`] /
//! [`crate::PdSensBacksolver`]; a `DiffHandoff` produced from a live
//! solve carries the converged solution and duals, and the factor is
//! reached through the owning [`crate::Solver`] / `ConvergedState`.
//!
//! It introduces no branch-and-bound and references no downstream
//! consumer: the test for belonging here is "would any differentiable
//! layer want it?" — and every one does.
23
24use pounce_common::types::{Index, Number};
25
26use crate::convenience::SensResult;
27
28/// Default activity tolerance: a constraint or bound multiplier with
29/// magnitude above this is treated as active. Matches the `_ACTIVE_TOL`
30/// long used by the Python JAX/torch backward passes
31/// (`python/pounce/jax/_diff.py`), centralized here so there is one
32/// documented knob rather than one per frontend.
33pub const DEFAULT_ACTIVE_TOL: Number = 1e-6;
34
35/// Everything the implicit-function-theorem backward pass needs from a
36/// converged solve, in a solver-agnostic shape.
37///
38/// Producers (IPM-NLP, convex LP/QP, conic, and — for discopt — the
39/// fixed-integer leaf of a branch-and-bound) emit this; consumers
40/// (`pounce.jax`, `pounce.torch`, the C ABI, a future Rust autodiff
41/// user, discopt across the `solve_nlp` seam) differentiate from it.
42///
43/// The multiplier sign / length conventions match the existing C ABI and
44/// Python `info` dict (`mult_g`, `mult_x_L`, `mult_x_U`), so this is a
45/// re-shape of data POUNCE already returns — not a new computation — plus
46/// the precomputed active-set masks, which are the genuinely new part.
47#[derive(Debug, Clone)]
48pub struct DiffHandoff {
49    // ---- primal / dual solution ----
50    /// Final primal iterate `x*` (length `n_x`).
51    pub x: Vec<Number>,
52    /// Objective value `f(x*)`.
53    pub obj_val: Number,
54    /// General-constraint multipliers `λ` (length `n_g`). The `g`/`G`/`A`
55    /// duals, depending on the solver; one name across all of them.
56    pub lambda: Vec<Number>,
57    /// Variable lower-bound multipliers `z_L` (length `n_x`).
58    pub mult_x_lower: Vec<Number>,
59    /// Variable upper-bound multipliers `z_U` (length `n_x`).
60    pub mult_x_upper: Vec<Number>,
61
62    // ---- active set, computed ONCE here ----
63    /// Constraint rows in the differentiated KKT block: equalities
64    /// (always) plus inequalities whose `|λ| > active_tol`. Length `n_g`.
65    /// Inactive (slack) rows drop out of the backward block.
66    pub active_constraints: Vec<bool>,
67    /// Variables pinned in the backward (`dx/dp = 0`): those with an
68    /// active bound (`max(z_L, z_U) > active_tol`) and — for a B&B leaf —
69    /// integer variables fixed at the optimum (see [`Self::pin`]).
70    /// Length `n_x`.
71    pub pinned_vars: Vec<bool>,
72    /// The activity tolerance used to derive the masks above. Recorded so
73    /// consumers and tests see the exact threshold.
74    pub active_tol: Number,
75}
76
77impl DiffHandoff {
78    /// Build a handoff from the raw converged solution and duals,
79    /// deriving the active-set masks with `active_tol`.
80    ///
81    /// `equality_mask[i]` is `true` when constraint `i` is an equality
82    /// (`g_l[i] == g_u[i]`) — such rows are always active. Pass an empty
83    /// slice when there are no general constraints.
84    pub fn from_solution(
85        x: Vec<Number>,
86        obj_val: Number,
87        lambda: Vec<Number>,
88        mult_x_lower: Vec<Number>,
89        mult_x_upper: Vec<Number>,
90        equality_mask: &[bool],
91        active_tol: Number,
92    ) -> Self {
93        debug_assert_eq!(mult_x_lower.len(), x.len(), "z_L length must match x");
94        debug_assert_eq!(mult_x_upper.len(), x.len(), "z_U length must match x");
95        let (pinned_vars, active_constraints) = Self::masks(
96            &mult_x_lower,
97            &mult_x_upper,
98            &lambda,
99            equality_mask,
100            active_tol,
101        );
102        Self {
103            x,
104            obj_val,
105            lambda,
106            mult_x_lower,
107            mult_x_upper,
108            active_constraints,
109            pinned_vars,
110            active_tol,
111        }
112    }
113
114    /// Derive the active-set masks `(pinned_vars, active_constraints)` from
115    /// borrowed duals — the single active-set derivation, shared by
116    /// [`Self::from_solution`] and producers that want only the masks (e.g.
117    /// the Python `info` dict) without surrendering the solution vectors.
118    /// Keeping the rule here means "`|mult| > active_tol`, equalities always
119    /// active" lives in exactly one place.
120    ///
121    /// `pinned_vars[i]` is `true` when variable `i`'s lower- or upper-bound
122    /// multiplier exceeds `active_tol`. `active_constraints[i]` is `true` for
123    /// an equality row (`equality_mask[i]`) or one whose `|lambda[i]| >
124    /// active_tol`. `equality_mask` may be empty (no equalities known) or
125    /// length `lambda.len()`.
126    pub fn masks(
127        mult_x_lower: &[Number],
128        mult_x_upper: &[Number],
129        lambda: &[Number],
130        equality_mask: &[bool],
131        active_tol: Number,
132    ) -> (Vec<bool>, Vec<bool>) {
133        debug_assert_eq!(
134            mult_x_lower.len(),
135            mult_x_upper.len(),
136            "z_L and z_U lengths must match"
137        );
138        debug_assert!(
139            equality_mask.is_empty() || equality_mask.len() == lambda.len(),
140            "equality_mask must be empty or length n_g"
141        );
142        // A bound is active when either side's multiplier exceeds the
143        // tolerance → the variable is pinned (dx/dp = 0).
144        let pinned_vars = mult_x_lower
145            .iter()
146            .zip(mult_x_upper.iter())
147            .map(|(&l, &u)| l > active_tol || u > active_tol)
148            .collect();
149        // A constraint row is active when it is an equality (always) or its
150        // multiplier magnitude exceeds the tolerance.
151        let active_constraints = lambda
152            .iter()
153            .enumerate()
154            .map(|(i, &lam)| {
155                equality_mask.get(i).copied().unwrap_or(false) || lam.abs() > active_tol
156            })
157            .collect();
158        (pinned_vars, active_constraints)
159    }
160
161    /// Re-shape a [`SensResult`] from a converged solve into a
162    /// `DiffHandoff`, using [`DEFAULT_ACTIVE_TOL`].
163    ///
164    /// Returns `None` when the solve did not populate the duals
165    /// (`mult_g` / `mult_x_l` / `mult_x_u`) — i.e. it didn't converge, or
166    /// the NLP didn't expose user-space multipliers.
167    ///
168    /// `equality_mask` is the caller's `g_l[i] == g_u[i]` test, length
169    /// `n_g`. **Pass the real mask whenever the problem has equality
170    /// constraints.** Equality rows are *always* part of the differentiated
171    /// KKT block regardless of multiplier magnitude, and the mask is the
172    /// only way `from_sens_result` learns which rows those are — a
173    /// [`SensResult`] carries the constraint *values* (`g`) but not their
174    /// `[g_l, g_u]` bounds, so equalities can't be recovered from it.
175    ///
176    /// An empty slice means "no equality information": a row then counts as
177    /// active only when `|λ| > active_tol`. That is correct **only** when
178    /// the problem has no equalities. ⚠ With equalities present it silently
179    /// drops any *degenerate* equality — one whose multiplier is ≈ 0
180    /// (redundant rows, or an equality not binding the optimum's curvature)
181    /// — from the active set, yielding a wrong backward block and wrong
182    /// gradients. Dropping a row is the *unsafe* direction, so the empty
183    /// slice is a convenience for the no-equality case, not a safe default.
184    pub fn from_sens_result(res: &SensResult, equality_mask: &[bool]) -> Option<Self> {
185        let x = res.x.clone()?;
186        let obj_val = res.obj_val?;
187        let lambda = res.mult_g.clone()?;
188        let mult_x_lower = res.mult_x_l.clone()?;
189        let mult_x_upper = res.mult_x_u.clone()?;
190        Some(Self::from_solution(
191            x,
192            obj_val,
193            lambda,
194            mult_x_lower,
195            mult_x_upper,
196            equality_mask,
197            DEFAULT_ACTIVE_TOL,
198        ))
199    }
200
201    /// Additionally pin a set of variables — the seam discopt uses for a
202    /// branch-and-bound leaf: integer variables fixed at the optimum
203    /// differentiate exactly like active bounds (`dx/dp = 0`). Indices
204    /// out of range are ignored.
205    pub fn pin(&mut self, indices: &[Index]) {
206        for &i in indices {
207            if i < 0 {
208                continue;
209            }
210            if let Some(slot) = self.pinned_vars.get_mut(i as usize) {
211                *slot = true;
212            }
213        }
214    }
215
216    /// Number of primal variables.
217    pub fn n_x(&self) -> usize {
218        self.x.len()
219    }
220
221    /// Number of general constraints.
222    pub fn n_g(&self) -> usize {
223        self.lambda.len()
224    }
225
226    /// Count of pinned variables (active bounds + any [`Self::pin`]ned).
227    pub fn n_pinned(&self) -> usize {
228        self.pinned_vars.iter().filter(|&&b| b).count()
229    }
230
231    /// Count of active constraint rows.
232    pub fn n_active_constraints(&self) -> usize {
233        self.active_constraints.iter().filter(|&&b| b).count()
234    }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_nlp::return_codes::ApplicationReturnStatus;

    /// A successful one-variable solve in which every optional output —
    /// the duals included — is left unpopulated. Tests override the fields
    /// they care about with struct-update syntax.
    fn solved_without_duals() -> SensResult {
        SensResult {
            status: ApplicationReturnStatus::SolveSucceeded,
            error: None,
            x: Some(vec![1.0]),
            obj_val: Some(0.0),
            dx: None,
            dx_full: None,
            reduced_hessian: None,
            reduced_hessian_scaled: None,
            obj_scaling_factor: None,
            pin_g_scaling: None,
            kkt_perturbations: None,
            reduced_hessian_eigenvalues: None,
            reduced_hessian_eigenvectors: None,
            mult_g: None,
            mult_x_l: None,
            mult_x_u: None,
            g: None,
        }
    }

    #[test]
    fn from_sens_result_degenerate_equality_needs_the_mask() {
        // A single equality row (g_l == g_u) with a ≈ 0 multiplier at the
        // solution, i.e. a degenerate / redundant equality. It must stay in
        // the backward block because equalities are always active, yet the
        // empty-mask path has no way to know it is an equality and (wrongly)
        // drops it. Both outcomes are asserted so the hazard documented on
        // `from_sens_result` is explicit and tested rather than silent.
        let res = SensResult {
            mult_g: Some(vec![0.0]), // degenerate: |λ| ≈ 0
            mult_x_l: Some(vec![0.0]),
            mult_x_u: Some(vec![0.0]),
            g: Some(vec![0.0]),
            ..solved_without_duals()
        };

        // No equality information: the degenerate row is dropped.
        let without_mask = DiffHandoff::from_sens_result(&res, &[]).unwrap();
        assert_eq!(without_mask.active_constraints, vec![false]);

        // Real mask: the row stays active whatever |λ| is.
        let with_mask = DiffHandoff::from_sens_result(&res, &[true]).unwrap();
        assert_eq!(with_mask.active_constraints, vec![true]);
    }

    #[test]
    fn from_sens_result_returns_none_without_duals() {
        // Duals not populated → no handoff.
        let res = solved_without_duals();
        assert!(DiffHandoff::from_sens_result(&res, &[]).is_none());
    }

    #[test]
    fn pins_active_bounds_and_marks_active_constraints() {
        // Variables — x0: lower bound active (large z_L); x1: free;
        // x2: upper bound active.
        let primal = vec![0.0, 1.0, 2.0];
        let lower_mult = vec![5.0, 0.0, 0.0];
        let upper_mult = vec![0.0, 0.0, 3.0];
        // Rows — g0: equality; g1: inactive inequality (λ ≈ 0);
        // g2: active inequality.
        let row_mult = vec![0.0, 1e-9, 4.0];
        let equalities = [true, false, false];

        let handoff = DiffHandoff::from_solution(
            primal,
            42.0,
            row_mult,
            lower_mult,
            upper_mult,
            &equalities,
            DEFAULT_ACTIVE_TOL,
        );

        assert_eq!(handoff.pinned_vars, vec![true, false, true]);
        assert_eq!(handoff.active_constraints, vec![true, false, true]);
        assert_eq!(handoff.n_pinned(), 2);
        assert_eq!(handoff.n_active_constraints(), 2);
        assert_eq!(handoff.obj_val, 42.0);
    }

    #[test]
    fn empty_equality_mask_treats_only_nonzero_rows_as_active() {
        let row_mult = vec![0.0, 5.0];
        let handoff = DiffHandoff::from_solution(
            vec![0.0],
            0.0,
            row_mult,
            vec![0.0],
            vec![0.0],
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        assert_eq!(handoff.active_constraints, vec![false, true]);
    }

    #[test]
    fn pin_adds_integer_variables() {
        let zeros = vec![0.0, 0.0, 0.0];
        let mut handoff = DiffHandoff::from_solution(
            zeros.clone(),
            0.0,
            vec![],
            zeros.clone(),
            zeros,
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        assert_eq!(handoff.n_pinned(), 0);

        // Index 99 is out of range and must be ignored.
        handoff.pin(&[1, 99]);
        assert_eq!(handoff.pinned_vars, vec![false, true, false]);
        assert_eq!(handoff.n_pinned(), 1);
    }
}