pounce_sensitivity/schur_data.rs
1//! `SchurData` trait surface and the `IndexSchurData` flavor.
2//!
3//! Direct port of upstream
4//! [`SensSchurData.hpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)
5//! (interface) and
6//! [`SensIndexSchurData.{hpp,cpp}`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)
7//! (the index-only specialization).
8//!
9//! # What `SchurData` represents
10//!
11//! `SchurData` is the matrix `B` in the augmented system
12//!
13//! ```text
14//! ⎡ K A ⎤
15//! ⎣ B 0 ⎦
16//! ```
17//!
18//! used by the sIPOPT step calculation (Pirnay, López-Negrete & Biegler 2012,
19//! §2, eq. 4). `K` is the converged KKT matrix from the original IPM solve;
20//! `A` and `B` carry the parameter-perturbation rows. The trait is the
21//! minimum surface every backend needs to expose so the
22//! `PCalculator` / `SchurDriver` family can stay matrix-shape-agnostic
23//! ([`SensSchurData.hpp:17-25`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
24//!
25//! # `IndexSchurData`
26//!
//! Specialization for parameter rows that each hold exactly one non-zero
//! entry, equal to ±1
//! ([`SensIndexSchurData.hpp:15-19`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
//! Most parametric / reduced-Hessian use cases fit this shape — the
//! parameter just picks out a subset of primal/dual variables — so the
//! ±1 sparsification is what production sIPOPT runs on. Pounce's
//! port stores parallel `Vec<Index>` arrays of column indices and signs,
//! the same layout as upstream's `idx_` / `val_`
//! ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
35
36use pounce_common::types::{Index, Number};
37
38/// Minimum surface for any matrix that lives in the augmented sIPOPT
39/// system's `A` / `B` slots. The numerical drivers in this crate
40/// (`PCalculator`, `SchurDriver`, `SensStepCalc`) consume `SchurData`
41/// objects and never touch the storage shape directly.
42///
43/// Mirrors `Ipopt::SchurData` from
44/// [`SensSchurData.hpp:29-178`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp).
45///
46/// # Lifecycle
47///
48/// A `SchurData` instance starts uninitialized. Exactly one of the
49/// `set_*` methods is called to populate it; subsequent reads via
50/// `nrows`, `multiply`, `trans_multiply`, etc. require an
51/// initialized instance. Upstream enforces this via `Set_Initialized`
52/// asserts in DBG builds
53/// ([`SensIndexSchurData.cpp:59-60`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp));
54/// pounce mirrors that invariant via `Result` returns on the read
55/// surface so a mis-ordered call surfaces as `Err(SchurDataError::NotInitialized)`
56/// instead of a panic.
57pub trait SchurData {
58 /// Number of rows the schur matrix has, i.e. the row count of `B`.
59 /// Upstream `GetNRowsAdded()`
60 /// ([`SensSchurData.hpp:77-80`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
61 fn nrows(&self) -> Index;
62
63 /// `true` if one of the `set_*` methods has been called and the
64 /// rest of the surface is safe to call. Upstream `Is_Initialized()`
65 /// ([`SensSchurData.hpp:82-85`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
66 fn is_initialized(&self) -> bool;
67
68 /// Set rows from a 0/1 flag array of length `dim`. For each `i`
69 /// with `flags[i] == 1`, add a row whose only non-zero column is
70 /// `i` with sign `sign(v)`. The magnitude of `v` is collapsed to
71 /// ±1 — this trait only ever stores signs, mirroring upstream's
72 /// `SetData_Flag(dim, flags, v)`
73 /// ([`SensSchurData.hpp:45-49`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
74 /// [`SensIndexSchurData.cpp:51-78`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
75 ///
76 /// Returns `Err(SchurDataError::AlreadyInitialized)` if the instance
77 /// was already initialized, or `Err(SchurDataError::InvalidFlag)` if any
78 /// `flags[i]` is not 0/1. On either error the instance is left
79 /// unchanged, so a corrected retry is safe.
80 fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError>;
81
82 /// Set rows from a list of column indices. Each `cols[k]` becomes
83 /// row `k` of `B`, with the single non-zero entry at column
84 /// `cols[k]` carrying sign `sign(v)`. Upstream `SetData_List`
85 /// ([`SensSchurData.hpp:64-67`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
86 /// [`SensIndexSchurData.cpp:149-167`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
87 ///
88 /// Returns `Err(_)` if already initialized.
89 fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError>;
90
91 /// Row-`i` access: return the parallel arrays `(indices, factors)`
92 /// such that row `i` of `B` has non-zero column entries
93 /// `factors[j]` at columns `indices[j]`. For `IndexSchurData`
94 /// `indices` has length 1 and `factors[0] == ±1.0`. Upstream
95 /// `GetMultiplyingVectors`
96 /// ([`SensSchurData.hpp:93-104`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
97 /// [`SensIndexSchurData.cpp:199-212`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
98 fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError>;
99
100 /// Apply `u = B v` for a `v` of length `n_full` and pre-sized
101 /// `u` buffer (length must equal `self.nrows()`). Upstream
102 /// `Multiply(IteratesVector, Vector)`
103 /// ([`SensSchurData.hpp:106-110`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
104 /// [`SensIndexSchurData.cpp:214-251`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
105 ///
106 /// In pounce we operate on flat `&[Number]` instead of upstream's
107 /// `IteratesVector` block layout because Phase-A `SchurData` is
108 /// shape-agnostic. Phase-B reconstructs the block layout where
109 /// the algorithm-side needs it.
110 fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError>;
111
112 /// Apply `v = Bᵀ u` for a `u` of length `self.nrows()` and a
113 /// pre-sized `v` buffer (length `n_full`). Upstream
114 /// `TransMultiply`
115 /// ([`SensSchurData.hpp:112-116`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
116 /// [`SensIndexSchurData.cpp:253-307`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
117 fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError>;
118}
119
/// Failure modes returned by [`SchurData`] read/write entry points and by
/// [`IndexSchurData::from_parts`]. Pounce returns these as `Err(_)` where
/// upstream's debug asserts would `DBG_ASSERT`.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so callers
/// can propagate it with `?` into boxed / generic error types.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurDataError {
    /// A read method was called before a `set_*` method initialized
    /// the instance. Upstream asserts `Is_Initialized()` in DBG builds
    /// (e.g. [`SensIndexSchurData.cpp:176`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
    NotInitialized,
    /// `set_*` called twice on the same instance. Upstream:
    /// [`SensIndexSchurData.cpp:59-69`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp).
    AlreadyInitialized,
    /// `set_from_flags` was passed a flag value other than 0/1. Upstream
    /// asserts `flags[i] == 0 || flags[i] == 1`
    /// ([`SensIndexSchurData.cpp:51-78`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp));
    /// pounce surfaces it as this distinct `Err` (not `AlreadyInitialized`,
    /// which would mislead — the instance is *not* initialized) and leaves
    /// the instance untouched so a corrected retry is safe.
    InvalidFlag,
    /// A row index was out of range (e.g. `multiplying_row(i)` with
    /// `i >= nrows()`).
    RowOutOfRange,
    /// Caller-supplied buffer length didn't match the expected shape.
    DimensionMismatch,
    /// `v == 0` passed to a sign-only `set_*` API, or a sign entry other
    /// than `1` / `-1` passed to [`IndexSchurData::from_parts`]. Upstream
    /// asserts `v != 0` ([`SensIndexSchurData.cpp:61`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
    ZeroSign,
}

impl std::fmt::Display for SchurDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            SchurDataError::NotInitialized => {
                "SchurData used before a set_* call initialized it"
            }
            SchurDataError::AlreadyInitialized => {
                "SchurData is already initialized; set_* may only be called once"
            }
            SchurDataError::InvalidFlag => "flag value must be 0 or 1",
            SchurDataError::RowOutOfRange => "row index out of range",
            SchurDataError::DimensionMismatch => {
                "buffer length does not match the expected shape"
            }
            SchurDataError::ZeroSign => {
                "sign value must be nonzero (stored signs must be +1 or -1)"
            }
        };
        f.write_str(msg)
    }
}

impl std::error::Error for SchurDataError {}
148
149/// Specialization for `B` matrices whose non-zero entries are ±1
150/// (the parametric / reduced-Hessian common case). Storage is two
151/// parallel arrays mirroring upstream's `idx_` and `val_`
152/// ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
153#[derive(Debug, Clone, Default, PartialEq, Eq)]
154pub struct IndexSchurData {
155 idx: Vec<Index>,
156 /// Stored as ±1 values, matching upstream's `Index`-typed `val_`.
157 /// Kept as `Index` (not `Number`) so the sign is exact and small.
158 val: Vec<Index>,
159 initialized: bool,
160}
161
162impl IndexSchurData {
163 /// Empty / uninitialized instance. Must be populated via one of
164 /// the `set_from_*` methods before being read.
165 pub fn new() -> Self {
166 Self::default()
167 }
168
169 /// Construct directly from pre-built `(idx, val)` arrays. `val`
170 /// must contain only ±1 entries. Mirrors upstream's two-arg
171 /// constructor
172 /// ([`SensIndexSchurData.cpp:24-36`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
173 pub fn from_parts(idx: Vec<Index>, val: Vec<Index>) -> Result<Self, SchurDataError> {
174 if idx.len() != val.len() {
175 return Err(SchurDataError::DimensionMismatch);
176 }
177 if val.iter().any(|&v| v != 1 && v != -1) {
178 return Err(SchurDataError::ZeroSign);
179 }
180 Ok(Self {
181 idx,
182 val,
183 initialized: true,
184 })
185 }
186
187 /// Column indices the rows refer to (one index per row). Upstream
188 /// `GetColIndices()`
189 /// ([`SensIndexSchurData.hpp:114`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
190 pub fn col_indices(&self) -> &[Index] {
191 &self.idx
192 }
193
194 /// Per-row ±1 sign carried by the single non-zero entry in that
195 /// row. Pounce-specific accessor; upstream exposes the data only
196 /// through the `multiplying_*` / `multiply` APIs.
197 pub fn signs(&self) -> &[Index] {
198 &self.val
199 }
200}
201
202impl SchurData for IndexSchurData {
203 fn nrows(&self) -> Index {
204 self.val.len() as Index
205 }
206
207 fn is_initialized(&self) -> bool {
208 self.initialized
209 }
210
211 fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError> {
212 if self.initialized {
213 return Err(SchurDataError::AlreadyInitialized);
214 }
215 if v == 0.0 {
216 return Err(SchurDataError::ZeroSign);
217 }
218 // Validate the whole flag array BEFORE mutating any state, so an
219 // invalid entry leaves the instance exactly as found. The previous
220 // code pushed rows as it scanned and bailed mid-loop on a bad flag,
221 // leaving `idx`/`val` partially populated with `initialized == false`
222 // — a caller that fixed the flags and retried would then append a
223 // second copy of the leading rows (duplicate rows). Upstream asserts
224 // `flags[i] ∈ {0,1}` (`SensIndexSchurData.cpp:51-78`); we surface the
225 // bad input as `InvalidFlag` rather than the misleading
226 // `AlreadyInitialized`.
227 if flags.iter().any(|&f| f != 0 && f != 1) {
228 return Err(SchurDataError::InvalidFlag);
229 }
230 let w: Index = if v > 0.0 { 1 } else { -1 };
231 for (i, &f) in flags.iter().enumerate() {
232 if f == 1 {
233 self.idx.push(i as Index);
234 self.val.push(w);
235 }
236 }
237 self.initialized = true;
238 Ok(())
239 }
240
241 fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError> {
242 if self.initialized {
243 return Err(SchurDataError::AlreadyInitialized);
244 }
245 if v == 0.0 {
246 return Err(SchurDataError::ZeroSign);
247 }
248 let w: Index = if v > 0.0 { 1 } else { -1 };
249 self.idx.extend_from_slice(cols);
250 self.val.resize(cols.len(), w);
251 self.initialized = true;
252 Ok(())
253 }
254
255 fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError> {
256 if !self.initialized {
257 return Err(SchurDataError::NotInitialized);
258 }
259 let i_us = i as usize;
260 if i_us >= self.idx.len() {
261 return Err(SchurDataError::RowOutOfRange);
262 }
263 Ok((vec![self.idx[i_us]], vec![self.val[i_us] as Number]))
264 }
265
266 fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError> {
267 if !self.initialized {
268 return Err(SchurDataError::NotInitialized);
269 }
270 if u.len() != self.idx.len() {
271 return Err(SchurDataError::DimensionMismatch);
272 }
273 for (i, slot) in u.iter_mut().enumerate() {
274 let col = self.idx[i] as usize;
275 if col >= v.len() {
276 return Err(SchurDataError::DimensionMismatch);
277 }
278 *slot = (self.val[i] as Number) * v[col];
279 }
280 Ok(())
281 }
282
283 fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError> {
284 if !self.initialized {
285 return Err(SchurDataError::NotInitialized);
286 }
287 if u.len() != self.idx.len() {
288 return Err(SchurDataError::DimensionMismatch);
289 }
290 for slot in v.iter_mut() {
291 *slot = 0.0;
292 }
293 for (i, &row_u) in u.iter().enumerate() {
294 let col = self.idx[i] as usize;
295 if col >= v.len() {
296 return Err(SchurDataError::DimensionMismatch);
297 }
298 v[col] += (self.val[i] as Number) * row_u;
299 }
300 Ok(())
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-variable example: select variables 1 and 3 with sign +1.
    /// `B = [[0 1 0 0], [0 0 0 1]]`.
    #[test]
    fn set_from_flags_round_trip() {
        let mut s = IndexSchurData::new();
        let flags = [0, 1, 0, 1];
        s.set_from_flags(&flags, 1.0).expect("init");
        assert_eq!(s.nrows(), 2);
        assert_eq!(s.col_indices(), &[1, 3]);
        assert_eq!(s.signs(), &[1, 1]);
        assert!(s.is_initialized());
    }

    #[test]
    fn set_from_flags_negative_sign_records_minus_one() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0, 1], -2.5).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
    }

    #[test]
    fn set_from_flags_rejects_double_init() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0], 1.0).expect("first init");
        assert_eq!(
            s.set_from_flags(&[0, 1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
    }

    #[test]
    fn set_from_flags_rejects_zero_sign() {
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 1], 0.0),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn set_from_flags_rejects_invalid_flag_with_distinct_variant() {
        // A flag value other than 0/1 is bad *input*, not a double-init —
        // it must surface as `InvalidFlag`, not `AlreadyInitialized`.
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 2, 1], 1.0),
            Err(SchurDataError::InvalidFlag),
        );
    }

    #[test]
    fn set_from_flags_invalid_flag_leaves_instance_pristine_for_retry() {
        // The bad flag sits AFTER a valid `1`, so a non-atomic
        // implementation would have already pushed row 0 before bailing.
        // The instance must be left untouched (uninitialized, empty) so a
        // corrected retry produces exactly the right rows — no duplicates.
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 5], 1.0),
            Err(SchurDataError::InvalidFlag),
        );
        assert!(
            !s.is_initialized(),
            "must stay uninitialized after a failed set"
        );
        assert_eq!(s.nrows(), 0, "no partial rows may linger");
        assert_eq!(s.col_indices(), &[] as &[Index]);

        // Corrected retry: selects vars 0 and 2 only.
        s.set_from_flags(&[1, 0, 1], 1.0).expect("retry init");
        assert_eq!(
            s.col_indices(),
            &[0, 2],
            "retry must not append to leftover state"
        );
        assert_eq!(s.signs(), &[1, 1]);
        assert!(s.is_initialized());
    }

    #[test]
    fn fresh_instance_is_uninitialized_with_zero_rows() {
        let s = IndexSchurData::new();
        assert!(!s.is_initialized());
        assert_eq!(s.nrows(), 0);
    }

    #[test]
    fn set_from_list_records_each_column_once() {
        let mut s = IndexSchurData::new();
        s.set_from_list(&[2, 0, 4], 1.0).expect("init");
        assert_eq!(s.nrows(), 3);
        assert_eq!(s.col_indices(), &[2, 0, 4]);
        assert_eq!(s.signs(), &[1, 1, 1]);
    }

    #[test]
    fn set_from_list_negative_sign_records_minus_one() {
        let mut s = IndexSchurData::new();
        s.set_from_list(&[4, 1], -0.5).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
    }

    #[test]
    fn set_from_list_rejects_double_init() {
        let mut s = IndexSchurData::new();
        s.set_from_list(&[0], 1.0).expect("first init");
        assert_eq!(
            s.set_from_list(&[1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
        assert_eq!(
            s.set_from_flags(&[1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
        // The rejected calls must not have touched the existing rows.
        assert_eq!(s.col_indices(), &[0]);
    }

    #[test]
    fn set_from_list_rejects_zero_sign_and_stays_uninitialized() {
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_list(&[0, 1], 0.0),
            Err(SchurDataError::ZeroSign),
        );
        assert!(!s.is_initialized());
        assert_eq!(s.nrows(), 0);
    }

    #[test]
    fn set_from_list_empty_initializes_with_no_rows() {
        let mut s = IndexSchurData::new();
        let no_cols: [Index; 0] = [];
        s.set_from_list(&no_cols, 1.0).expect("init");
        assert!(s.is_initialized());
        assert_eq!(s.nrows(), 0);
    }

    #[test]
    fn from_parts_validates_signs() {
        // Mixed ±1 OK
        let ok = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).expect("ok");
        assert_eq!(ok.signs(), &[1, -1]);
        // Length mismatch
        assert_eq!(
            IndexSchurData::from_parts(vec![0, 2], vec![1]),
            Err(SchurDataError::DimensionMismatch),
        );
        // Non-±1
        assert_eq!(
            IndexSchurData::from_parts(vec![0], vec![2]),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn multiply_picks_selected_columns_with_signs() {
        // B = [[0 1 0 0], [0 0 0 -1]]
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let v = [10.0, 20.0, 30.0, 40.0];
        let mut u = [0.0; 2];
        s.multiply(&v, &mut u).expect("ok");
        // u[0] = +1·v[1] = 20
        // u[1] = -1·v[3] = -40
        assert_eq!(u, [20.0, -40.0]);
    }

    #[test]
    fn multiply_rejects_mismatched_u_length() {
        let s = IndexSchurData::from_parts(vec![0, 1], vec![1, 1]).expect("ok");
        let v = [1.0, 2.0];
        let mut u = [0.0; 1];
        assert_eq!(
            s.multiply(&v, &mut u),
            Err(SchurDataError::DimensionMismatch)
        );
    }

    #[test]
    fn multiply_rejects_column_outside_v() {
        let s = IndexSchurData::from_parts(vec![3], vec![1]).expect("ok");
        let v = [1.0, 2.0];
        let mut u = [0.0; 1];
        assert_eq!(
            s.multiply(&v, &mut u),
            Err(SchurDataError::DimensionMismatch)
        );
    }

    #[test]
    fn multiply_rejects_uninitialized() {
        let s = IndexSchurData::new();
        let v = [0.0];
        let mut u = [0.0];
        assert_eq!(s.multiply(&v, &mut u), Err(SchurDataError::NotInitialized));
    }

    #[test]
    fn trans_multiply_scatters_with_signs() {
        // B from the multiply test; Bᵀ u with u = [3, 5] should produce
        // v = (0, 3, 0, -5).
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let u = [3.0, 5.0];
        let mut v = [0.0; 4];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [0.0, 3.0, 0.0, -5.0]);
    }

    #[test]
    fn trans_multiply_overwrites_caller_buffer() {
        // Existing entries in v are zeroed before the scatter.
        let s = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).expect("ok");
        let u = [1.0, 2.0];
        let mut v = [99.0, 99.0, 99.0, 99.0];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [1.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn trans_multiply_accumulates_duplicate_columns() {
        // Two rows hitting column 1 with opposite signs: v[1] = 2 - 5.
        let s = IndexSchurData::from_parts(vec![1, 1], vec![1, -1]).expect("ok");
        let mut v = [0.0; 3];
        s.trans_multiply(&[2.0, 5.0], &mut v).expect("ok");
        assert_eq!(v, [0.0, -3.0, 0.0]);
    }

    #[test]
    fn trans_multiply_rejects_bad_shapes_and_uninitialized() {
        let s = IndexSchurData::from_parts(vec![3], vec![1]).expect("ok");
        let mut v = [0.0; 2];
        // Wrong `u` length.
        assert_eq!(
            s.trans_multiply(&[1.0, 2.0], &mut v),
            Err(SchurDataError::DimensionMismatch)
        );
        // Column 3 does not fit in a length-2 `v`.
        assert_eq!(
            s.trans_multiply(&[1.0], &mut v),
            Err(SchurDataError::DimensionMismatch)
        );
        let fresh = IndexSchurData::new();
        let no_rows: [Number; 0] = [];
        assert_eq!(
            fresh.trans_multiply(&no_rows, &mut v),
            Err(SchurDataError::NotInitialized)
        );
    }

    #[test]
    fn multiplying_row_out_of_range() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(2), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_rejects_boundary_and_negative_index() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(1), Err(SchurDataError::RowOutOfRange));
        assert_eq!(s.multiplying_row(-1), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_rejects_uninitialized() {
        let s = IndexSchurData::new();
        assert_eq!(s.multiplying_row(0), Err(SchurDataError::NotInitialized));
    }

    #[test]
    fn multiplying_row_returns_single_entry_for_index_schur_data() {
        // Per upstream `SensIndexSchurData.cpp:199-212`, `IndexSchurData`
        // rows always have exactly one non-zero entry — pounce mirrors
        // that contract.
        let s = IndexSchurData::from_parts(vec![5, 7], vec![1, -1]).expect("ok");
        let (idxs, facs) = s.multiplying_row(0).expect("ok");
        assert_eq!(idxs, &[5]);
        assert_eq!(facs, &[1.0]);
        let (idxs, facs) = s.multiplying_row(1).expect("ok");
        assert_eq!(idxs, &[7]);
        assert_eq!(facs, &[-1.0]);
    }
}