pounce_sensitivity/diff_handoff.rs
1//! The `solve → DiffHandoff` contract — the solver-agnostic bundle that
2//! every differentiable solve hands to its backward pass.
3//!
4//! Design: `dev-notes/diff-handoff-contract.md`. The motivation is that
5//! POUNCE differentiates solves across several frontends (JAX / PyTorch,
6//! NLP / QP) and each was re-deriving the same *active-set* facts —
7//! "a bound is active when its multiplier exceeds a tolerance", "an
8//! equality row is always active", "active-bound / fixed (e.g. integer)
9//! variables are pinned, `dx/dp = 0`". This struct computes those facts
10//! **once**, in the producer, so every consumer reads them instead of
11//! recomputing `|mult| > tol` under its own tolerance.
12//!
13//! This module is intentionally small and dependency-light: it is plain
14//! data plus the one active-set derivation. It does *not* own the KKT
15//! factor's linear algebra — that stays in [`crate::solver`] /
16//! [`crate::PdSensBacksolver`]; a `DiffHandoff` produced from a live
17//! solve carries the converged solution and duals, and the factor is
//! reached through the owning [`crate::Solver`] / `ConvergedState`.
19//!
20//! It introduces no branch-and-bound and references no downstream
21//! consumer: the test for belonging here is "would any differentiable
22//! layer want it?" — and every one does.
23
24use pounce_common::types::{Index, Number};
25
26use crate::convenience::SensResult;
27
28/// Default activity tolerance: a constraint or bound multiplier with
29/// magnitude above this is treated as active. Matches the `_ACTIVE_TOL`
30/// long used by the Python JAX/torch backward passes
31/// (`python/pounce/jax/_diff.py`), centralized here so there is one
32/// documented knob rather than one per frontend.
33pub const DEFAULT_ACTIVE_TOL: Number = 1e-6;
34
35/// Everything the implicit-function-theorem backward pass needs from a
36/// converged solve, in a solver-agnostic shape.
37///
38/// Producers (IPM-NLP, convex LP/QP, conic, and — for discopt — the
39/// fixed-integer leaf of a branch-and-bound) emit this; consumers
40/// (`pounce.jax`, `pounce.torch`, the C ABI, a future Rust autodiff
41/// user, discopt across the `solve_nlp` seam) differentiate from it.
42///
43/// The multiplier sign / length conventions match the existing C ABI and
44/// Python `info` dict (`mult_g`, `mult_x_L`, `mult_x_U`), so this is a
45/// re-shape of data POUNCE already returns — not a new computation — plus
46/// the precomputed active-set masks, which are the genuinely new part.
47#[derive(Debug, Clone)]
48pub struct DiffHandoff {
49 // ---- primal / dual solution ----
50 /// Final primal iterate `x*` (length `n_x`).
51 pub x: Vec<Number>,
52 /// Objective value `f(x*)`.
53 pub obj_val: Number,
54 /// General-constraint multipliers `λ` (length `n_g`). The `g`/`G`/`A`
55 /// duals, depending on the solver; one name across all of them.
56 pub lambda: Vec<Number>,
57 /// Variable lower-bound multipliers `z_L` (length `n_x`).
58 pub mult_x_lower: Vec<Number>,
59 /// Variable upper-bound multipliers `z_U` (length `n_x`).
60 pub mult_x_upper: Vec<Number>,
61
62 // ---- active set, computed ONCE here ----
63 /// Constraint rows in the differentiated KKT block: equalities
64 /// (always) plus inequalities whose `|λ| > active_tol`. Length `n_g`.
65 /// Inactive (slack) rows drop out of the backward block.
66 pub active_constraints: Vec<bool>,
67 /// Variables pinned in the backward (`dx/dp = 0`): those with an
68 /// active bound (`max(z_L, z_U) > active_tol`) and — for a B&B leaf —
69 /// integer variables fixed at the optimum (see [`Self::pin`]).
70 /// Length `n_x`.
71 pub pinned_vars: Vec<bool>,
72 /// The activity tolerance used to derive the masks above. Recorded so
73 /// consumers and tests see the exact threshold.
74 pub active_tol: Number,
75}
76
77impl DiffHandoff {
78 /// Build a handoff from the raw converged solution and duals,
79 /// deriving the active-set masks with `active_tol`.
80 ///
81 /// `equality_mask[i]` is `true` when constraint `i` is an equality
82 /// (`g_l[i] == g_u[i]`) — such rows are always active. Pass an empty
83 /// slice when there are no general constraints.
84 pub fn from_solution(
85 x: Vec<Number>,
86 obj_val: Number,
87 lambda: Vec<Number>,
88 mult_x_lower: Vec<Number>,
89 mult_x_upper: Vec<Number>,
90 equality_mask: &[bool],
91 active_tol: Number,
92 ) -> Self {
93 debug_assert_eq!(mult_x_lower.len(), x.len(), "z_L length must match x");
94 debug_assert_eq!(mult_x_upper.len(), x.len(), "z_U length must match x");
95 let (pinned_vars, active_constraints) = Self::masks(
96 &mult_x_lower,
97 &mult_x_upper,
98 &lambda,
99 equality_mask,
100 active_tol,
101 );
102 Self {
103 x,
104 obj_val,
105 lambda,
106 mult_x_lower,
107 mult_x_upper,
108 active_constraints,
109 pinned_vars,
110 active_tol,
111 }
112 }
113
114 /// Derive the active-set masks `(pinned_vars, active_constraints)` from
115 /// borrowed duals — the single active-set derivation, shared by
116 /// [`Self::from_solution`] and producers that want only the masks (e.g.
117 /// the Python `info` dict) without surrendering the solution vectors.
118 /// Keeping the rule here means "`|mult| > active_tol`, equalities always
119 /// active" lives in exactly one place.
120 ///
121 /// `pinned_vars[i]` is `true` when variable `i`'s lower- or upper-bound
122 /// multiplier exceeds `active_tol`. `active_constraints[i]` is `true` for
123 /// an equality row (`equality_mask[i]`) or one whose `|lambda[i]| >
124 /// active_tol`. `equality_mask` may be empty (no equalities known) or
125 /// length `lambda.len()`.
126 pub fn masks(
127 mult_x_lower: &[Number],
128 mult_x_upper: &[Number],
129 lambda: &[Number],
130 equality_mask: &[bool],
131 active_tol: Number,
132 ) -> (Vec<bool>, Vec<bool>) {
133 debug_assert_eq!(
134 mult_x_lower.len(),
135 mult_x_upper.len(),
136 "z_L and z_U lengths must match"
137 );
138 debug_assert!(
139 equality_mask.is_empty() || equality_mask.len() == lambda.len(),
140 "equality_mask must be empty or length n_g"
141 );
142 // A bound is active when either side's multiplier exceeds the
143 // tolerance → the variable is pinned (dx/dp = 0).
144 let pinned_vars = mult_x_lower
145 .iter()
146 .zip(mult_x_upper.iter())
147 .map(|(&l, &u)| l > active_tol || u > active_tol)
148 .collect();
149 // A constraint row is active when it is an equality (always) or its
150 // multiplier magnitude exceeds the tolerance.
151 let active_constraints = lambda
152 .iter()
153 .enumerate()
154 .map(|(i, &lam)| {
155 equality_mask.get(i).copied().unwrap_or(false) || lam.abs() > active_tol
156 })
157 .collect();
158 (pinned_vars, active_constraints)
159 }
160
161 /// Re-shape a [`SensResult`] from a converged solve into a
162 /// `DiffHandoff`, using [`DEFAULT_ACTIVE_TOL`].
163 ///
164 /// Returns `None` when the solve did not populate the duals
165 /// (`mult_g` / `mult_x_l` / `mult_x_u`) — i.e. it didn't converge, or
166 /// the NLP didn't expose user-space multipliers.
167 ///
168 /// `equality_mask` is the caller's `g_l[i] == g_u[i]` test, length
169 /// `n_g`. **Pass the real mask whenever the problem has equality
170 /// constraints.** Equality rows are *always* part of the differentiated
171 /// KKT block regardless of multiplier magnitude, and the mask is the
172 /// only way `from_sens_result` learns which rows those are — a
173 /// [`SensResult`] carries the constraint *values* (`g`) but not their
174 /// `[g_l, g_u]` bounds, so equalities can't be recovered from it.
175 ///
176 /// An empty slice means "no equality information": a row then counts as
177 /// active only when `|λ| > active_tol`. That is correct **only** when
178 /// the problem has no equalities. ⚠ With equalities present it silently
179 /// drops any *degenerate* equality — one whose multiplier is ≈ 0
180 /// (redundant rows, or an equality not binding the optimum's curvature)
181 /// — from the active set, yielding a wrong backward block and wrong
182 /// gradients. Dropping a row is the *unsafe* direction, so the empty
183 /// slice is a convenience for the no-equality case, not a safe default.
184 pub fn from_sens_result(res: &SensResult, equality_mask: &[bool]) -> Option<Self> {
185 let x = res.x.clone()?;
186 let obj_val = res.obj_val?;
187 let lambda = res.mult_g.clone()?;
188 let mult_x_lower = res.mult_x_l.clone()?;
189 let mult_x_upper = res.mult_x_u.clone()?;
190 Some(Self::from_solution(
191 x,
192 obj_val,
193 lambda,
194 mult_x_lower,
195 mult_x_upper,
196 equality_mask,
197 DEFAULT_ACTIVE_TOL,
198 ))
199 }
200
201 /// Additionally pin a set of variables — the seam discopt uses for a
202 /// branch-and-bound leaf: integer variables fixed at the optimum
203 /// differentiate exactly like active bounds (`dx/dp = 0`). Indices
204 /// out of range are ignored.
205 pub fn pin(&mut self, indices: &[Index]) {
206 for &i in indices {
207 if i < 0 {
208 continue;
209 }
210 if let Some(slot) = self.pinned_vars.get_mut(i as usize) {
211 *slot = true;
212 }
213 }
214 }
215
216 /// Number of primal variables.
217 pub fn n_x(&self) -> usize {
218 self.x.len()
219 }
220
221 /// Number of general constraints.
222 pub fn n_g(&self) -> usize {
223 self.lambda.len()
224 }
225
226 /// Count of pinned variables (active bounds + any [`Self::pin`]ned).
227 pub fn n_pinned(&self) -> usize {
228 self.pinned_vars.iter().filter(|&&b| b).count()
229 }
230
231 /// Count of active constraint rows.
232 pub fn n_active_constraints(&self) -> usize {
233 self.active_constraints.iter().filter(|&&b| b).count()
234 }
235}
236
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_nlp::return_codes::ApplicationReturnStatus;

    /// A converged one-variable `SensResult` with `x = [1.0]` and
    /// `obj_val = 0.0`, but with the duals (`mult_g`, `mult_x_l`,
    /// `mult_x_u`) and constraint values `g` left unpopulated. Each test
    /// fills in only the fields it exercises, so the long field list lives
    /// in one place.
    fn sens_result_without_duals() -> SensResult {
        SensResult {
            status: ApplicationReturnStatus::SolveSucceeded,
            error: None,
            x: Some(vec![1.0]),
            obj_val: Some(0.0),
            dx: None,
            dx_full: None,
            reduced_hessian: None,
            reduced_hessian_scaled: None,
            obj_scaling_factor: None,
            pin_g_scaling: None,
            kkt_perturbations: None,
            reduced_hessian_eigenvalues: None,
            reduced_hessian_eigenvectors: None,
            mult_g: None,
            mult_x_l: None,
            mult_x_u: None,
            g: None,
        }
    }

    #[test]
    fn from_sens_result_degenerate_equality_needs_the_mask() {
        // One equality constraint (g_l == g_u) whose multiplier is ≈ 0 at
        // the solution: a degenerate / redundant equality. Equalities are
        // always active, so it belongs in the backward block. The
        // empty-mask path can't know it's an equality and (wrongly) drops
        // it. Pin BOTH behaviors so the hazard documented on
        // `from_sens_result` is explicit and tested, not silent.
        let mut res = sens_result_without_duals();
        res.mult_g = Some(vec![0.0]); // degenerate: |λ| ≈ 0
        res.mult_x_l = Some(vec![0.0]);
        res.mult_x_u = Some(vec![0.0]);
        res.g = Some(vec![0.0]);

        // Empty mask: no equality info → the degenerate equality is dropped.
        let dropped = DiffHandoff::from_sens_result(&res, &[]).unwrap();
        assert_eq!(dropped.active_constraints, vec![false]);

        // Correct mask: the equality stays active regardless of |λ|.
        let kept = DiffHandoff::from_sens_result(&res, &[true]).unwrap();
        assert_eq!(kept.active_constraints, vec![true]);
    }

    #[test]
    fn from_sens_result_returns_none_without_duals() {
        // Duals not populated → no handoff.
        let res = sens_result_without_duals();
        assert!(DiffHandoff::from_sens_result(&res, &[]).is_none());
    }

    #[test]
    fn from_sens_result_returns_none_without_primal() {
        // Duals alone are not enough: a missing `x` also yields no handoff.
        let mut res = sens_result_without_duals();
        res.x = None;
        res.mult_g = Some(vec![1.0]);
        res.mult_x_l = Some(vec![0.0]);
        res.mult_x_u = Some(vec![0.0]);
        assert!(DiffHandoff::from_sens_result(&res, &[]).is_none());
    }

    #[test]
    fn pins_active_bounds_and_marks_active_constraints() {
        // x0: lower bound active (z_L large). x1: free. x2: upper active.
        let x = vec![0.0, 1.0, 2.0];
        let z_l = vec![5.0, 0.0, 0.0];
        let z_u = vec![0.0, 0.0, 3.0];
        // g0: equality. g1: inactive inequality (λ≈0). g2: active inequality.
        let lambda = vec![0.0, 1e-9, 4.0];
        let eq = vec![true, false, false];

        let h = DiffHandoff::from_solution(x, 42.0, lambda, z_l, z_u, &eq, DEFAULT_ACTIVE_TOL);

        assert_eq!(h.pinned_vars, vec![true, false, true]);
        assert_eq!(h.active_constraints, vec![true, false, true]);
        assert_eq!(h.n_pinned(), 2);
        assert_eq!(h.n_active_constraints(), 2);
        assert_eq!(h.n_x(), 3);
        assert_eq!(h.n_g(), 3);
        assert_eq!(h.obj_val, 42.0);
    }

    #[test]
    fn masks_agree_with_from_solution() {
        // The borrowed-mask entry point must give exactly what the owning
        // constructor stores: there is one active-set rule, not two.
        let z_l = vec![5.0, 0.0, 0.0];
        let z_u = vec![0.0, 0.0, 3.0];
        let lambda = vec![0.0, 1e-9, 4.0];
        let eq = vec![true, false, false];

        let (pinned, active) = DiffHandoff::masks(&z_l, &z_u, &lambda, &eq, DEFAULT_ACTIVE_TOL);
        let h = DiffHandoff::from_solution(
            vec![0.0, 1.0, 2.0],
            42.0,
            lambda,
            z_l,
            z_u,
            &eq,
            DEFAULT_ACTIVE_TOL,
        );

        assert_eq!(pinned, h.pinned_vars);
        assert_eq!(active, h.active_constraints);
        assert_eq!(h.active_tol, DEFAULT_ACTIVE_TOL);
    }

    #[test]
    fn empty_equality_mask_treats_only_nonzero_rows_as_active() {
        let h = DiffHandoff::from_solution(
            vec![0.0],
            0.0,
            vec![0.0, 5.0],
            vec![0.0],
            vec![0.0],
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        assert_eq!(h.active_constraints, vec![false, true]);
    }

    #[test]
    fn pin_adds_integer_variables() {
        let mut h = DiffHandoff::from_solution(
            vec![0.0, 0.0, 0.0],
            0.0,
            vec![],
            vec![0.0, 0.0, 0.0],
            vec![0.0, 0.0, 0.0],
            &[],
            DEFAULT_ACTIVE_TOL,
        );
        assert_eq!(h.n_pinned(), 0);
        h.pin(&[1, 99, -1]); // 99 is out of range and -1 is negative: both ignored
        assert_eq!(h.pinned_vars, vec![false, true, false]);
        assert_eq!(h.n_pinned(), 1);
    }
}
352}