pounce_sensitivity/schur_data.rs
1//! `SchurData` trait surface and the `IndexSchurData` flavor.
2//!
3//! Direct port of upstream
4//! [`SensSchurData.hpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)
5//! (interface) and
6//! [`SensIndexSchurData.{hpp,cpp}`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)
7//! (the index-only specialization).
8//!
9//! # What `SchurData` represents
10//!
11//! `SchurData` is the matrix `B` in the augmented system
12//!
13//! ```text
14//! ⎡ K A ⎤
15//! ⎣ B 0 ⎦
16//! ```
17//!
18//! used by the sIPOPT step calculation (Pirnay, López-Negrete & Biegler 2012,
19//! §2, eq. 4). `K` is the converged KKT matrix from the original IPM solve;
20//! `A` and `B` carry the parameter-perturbation rows. The trait is the
21//! minimum surface every backend needs to expose so the
22//! `PCalculator` / `SchurDriver` family can stay matrix-shape-agnostic
23//! ([`SensSchurData.hpp:17-25`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
24//!
25//! # `IndexSchurData`
26//!
27//! Specialization for parameter rows whose only non-zero entries are
28//! ±1 ([`SensIndexSchurData.hpp:15-19`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
29//! Most parametric / reduced-Hessian use cases fit this shape — the
30//! parameter just picks out a subset of primal/dual variables — so the
31//! ±1 sparsification is what production sIPOPT runs on. Pounce's
32//! port stores parallel `Vec<i32>` arrays of indices and signs, same
33//! as upstream's `idx_` / `val_`
34//! ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
35
36use pounce_common::types::{Index, Number};
37
38/// Minimum surface for any matrix that lives in the augmented sIPOPT
39/// system's `A` / `B` slots. The numerical drivers in this crate
40/// (`PCalculator`, `SchurDriver`, `SensStepCalc`) consume `SchurData`
41/// objects and never touch the storage shape directly.
42///
43/// Mirrors `Ipopt::SchurData` from
44/// [`SensSchurData.hpp:29-178`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp).
45///
46/// # Lifecycle
47///
48/// A `SchurData` instance starts uninitialized. Exactly one of the
49/// `set_*` methods is called to populate it; subsequent reads via
50/// `nrows`, `multiply`, `trans_multiply`, etc. require an
51/// initialized instance. Upstream enforces this via `Set_Initialized`
52/// asserts in DBG builds
53/// ([`SensIndexSchurData.cpp:59-60`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp));
54/// pounce mirrors that invariant via `Result` returns on the read
55/// surface so a mis-ordered call surfaces as `Err(SchurDataError::NotInitialized)`
56/// instead of a panic.
57pub trait SchurData {
58 /// Number of rows the schur matrix has, i.e. the row count of `B`.
59 /// Upstream `GetNRowsAdded()`
60 /// ([`SensSchurData.hpp:77-80`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
61 fn nrows(&self) -> Index;
62
63 /// `true` if one of the `set_*` methods has been called and the
64 /// rest of the surface is safe to call. Upstream `Is_Initialized()`
65 /// ([`SensSchurData.hpp:82-85`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
66 fn is_initialized(&self) -> bool;
67
68 /// Set rows from a 0/1 flag array of length `dim`. For each `i`
69 /// with `flags[i] == 1`, add a row whose only non-zero column is
70 /// `i` with sign `sign(v)`. The magnitude of `v` is collapsed to
71 /// ±1 — this trait only ever stores signs, mirroring upstream's
72 /// `SetData_Flag(dim, flags, v)`
73 /// ([`SensSchurData.hpp:45-49`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
74 /// [`SensIndexSchurData.cpp:51-78`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
75 ///
76 /// Returns `Err(_)` if the instance was already initialized.
77 fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError>;
78
79 /// Set rows from a list of column indices. Each `cols[k]` becomes
80 /// row `k` of `B`, with the single non-zero entry at column
81 /// `cols[k]` carrying sign `sign(v)`. Upstream `SetData_List`
82 /// ([`SensSchurData.hpp:64-67`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
83 /// [`SensIndexSchurData.cpp:149-167`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
84 ///
85 /// Returns `Err(_)` if already initialized.
86 fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError>;
87
88 /// Row-`i` access: return the parallel arrays `(indices, factors)`
89 /// such that row `i` of `B` has non-zero column entries
90 /// `factors[j]` at columns `indices[j]`. For `IndexSchurData`
91 /// `indices` has length 1 and `factors[0] == ±1.0`. Upstream
92 /// `GetMultiplyingVectors`
93 /// ([`SensSchurData.hpp:93-104`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
94 /// [`SensIndexSchurData.cpp:199-212`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
95 fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError>;
96
97 /// Apply `u = B v` for a `v` of length `n_full` and pre-sized
98 /// `u` buffer (length must equal `self.nrows()`). Upstream
99 /// `Multiply(IteratesVector, Vector)`
100 /// ([`SensSchurData.hpp:106-110`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
101 /// [`SensIndexSchurData.cpp:214-251`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
102 ///
103 /// In pounce we operate on flat `&[Number]` instead of upstream's
104 /// `IteratesVector` block layout because Phase-A `SchurData` is
105 /// shape-agnostic. Phase-B reconstructs the block layout where
106 /// the algorithm-side needs it.
107 fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError>;
108
109 /// Apply `v = Bᵀ u` for a `u` of length `self.nrows()` and a
110 /// pre-sized `v` buffer (length `n_full`). Upstream
111 /// `TransMultiply`
112 /// ([`SensSchurData.hpp:112-116`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
113 /// [`SensIndexSchurData.cpp:253-307`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
114 fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError>;
115}
116
/// Failure modes returned by the [`SchurData`] entry points and by
/// `IndexSchurData::from_parts`.
///
/// Pounce returns these as `Err(_)` where upstream sIPOPT relies on
/// `DBG_ASSERT`. A contract violation is therefore reported to the caller
/// rather than aborting in debug builds or going unchecked in release
/// builds.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurDataError {
    /// A read method was called before a `set_*` method initialized
    /// the instance. Upstream asserts `Is_Initialized()` in DBG builds
    /// (e.g. [`SensIndexSchurData.cpp:176`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
    NotInitialized,
    /// A `set_*` method was called on an already-initialized instance,
    /// or the `flags` array passed to `set_from_flags` contained a value
    /// other than 0/1. Upstream:
    /// [`SensIndexSchurData.cpp:59-69`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp).
    AlreadyInitialized,
    /// A row index was out of range (e.g. `multiplying_row(i)` with
    /// `i >= nrows()`).
    RowOutOfRange,
    /// A caller-supplied buffer length did not match the expected shape.
    DimensionMismatch,
    /// A sign-only entry point received a value with no usable ±1 sign.
    /// This covers `v == 0` passed to a `set_*` method (upstream asserts
    /// `v != 0`, [`SensIndexSchurData.cpp:61`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp))
    /// and a `val` entry other than ±1 passed to `IndexSchurData::from_parts`.
    ZeroSign,
}

impl std::fmt::Display for SchurDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::NotInitialized => "schur data has not been initialized by a set_* method",
            Self::AlreadyInitialized => {
                "schur data is already initialized or the flag array is not 0/1"
            }
            Self::RowOutOfRange => "row index is out of range",
            Self::DimensionMismatch => "buffer length does not match the schur data shape",
            Self::ZeroSign => "value has no usable +1/-1 sign",
        };
        f.write_str(msg)
    }
}

/// Lets `SchurDataError` flow through `?` into `Box<dyn Error>` and similar
/// error types. It never wraps an underlying cause.
impl std::error::Error for SchurDataError {}
139
140/// Specialization for `B` matrices whose non-zero entries are ±1
141/// (the parametric / reduced-Hessian common case). Storage is two
142/// parallel arrays mirroring upstream's `idx_` and `val_`
143/// ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
144#[derive(Debug, Clone, Default, PartialEq, Eq)]
145pub struct IndexSchurData {
146 idx: Vec<Index>,
147 /// Stored as ±1 values, matching upstream's `Index`-typed `val_`.
148 /// Kept as `Index` (not `Number`) so the sign is exact and small.
149 val: Vec<Index>,
150 initialized: bool,
151}
152
153impl IndexSchurData {
154 /// Empty / uninitialized instance. Must be populated via one of
155 /// the `set_from_*` methods before being read.
156 pub fn new() -> Self {
157 Self::default()
158 }
159
160 /// Construct directly from pre-built `(idx, val)` arrays. `val`
161 /// must contain only ±1 entries. Mirrors upstream's two-arg
162 /// constructor
163 /// ([`SensIndexSchurData.cpp:24-36`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
164 pub fn from_parts(idx: Vec<Index>, val: Vec<Index>) -> Result<Self, SchurDataError> {
165 if idx.len() != val.len() {
166 return Err(SchurDataError::DimensionMismatch);
167 }
168 if val.iter().any(|&v| v != 1 && v != -1) {
169 return Err(SchurDataError::ZeroSign);
170 }
171 Ok(Self {
172 idx,
173 val,
174 initialized: true,
175 })
176 }
177
178 /// Column indices the rows refer to (one index per row). Upstream
179 /// `GetColIndices()`
180 /// ([`SensIndexSchurData.hpp:114`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
181 pub fn col_indices(&self) -> &[Index] {
182 &self.idx
183 }
184
185 /// Per-row ±1 sign carried by the single non-zero entry in that
186 /// row. Pounce-specific accessor; upstream exposes the data only
187 /// through the `multiplying_*` / `multiply` APIs.
188 pub fn signs(&self) -> &[Index] {
189 &self.val
190 }
191}
192
193impl SchurData for IndexSchurData {
194 fn nrows(&self) -> Index {
195 self.val.len() as Index
196 }
197
198 fn is_initialized(&self) -> bool {
199 self.initialized
200 }
201
202 fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError> {
203 if self.initialized {
204 return Err(SchurDataError::AlreadyInitialized);
205 }
206 if v == 0.0 {
207 return Err(SchurDataError::ZeroSign);
208 }
209 let w: Index = if v > 0.0 { 1 } else { -1 };
210 for (i, &f) in flags.iter().enumerate() {
211 match f {
212 0 => {}
213 1 => {
214 self.idx.push(i as Index);
215 self.val.push(w);
216 }
217 _ => return Err(SchurDataError::AlreadyInitialized), // upstream asserts flag ∈ {0,1}
218 }
219 }
220 self.initialized = true;
221 Ok(())
222 }
223
224 fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError> {
225 if self.initialized {
226 return Err(SchurDataError::AlreadyInitialized);
227 }
228 if v == 0.0 {
229 return Err(SchurDataError::ZeroSign);
230 }
231 let w: Index = if v > 0.0 { 1 } else { -1 };
232 self.idx.extend_from_slice(cols);
233 self.val.resize(cols.len(), w);
234 self.initialized = true;
235 Ok(())
236 }
237
238 fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError> {
239 if !self.initialized {
240 return Err(SchurDataError::NotInitialized);
241 }
242 let i_us = i as usize;
243 if i_us >= self.idx.len() {
244 return Err(SchurDataError::RowOutOfRange);
245 }
246 Ok((vec![self.idx[i_us]], vec![self.val[i_us] as Number]))
247 }
248
249 fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError> {
250 if !self.initialized {
251 return Err(SchurDataError::NotInitialized);
252 }
253 if u.len() != self.idx.len() {
254 return Err(SchurDataError::DimensionMismatch);
255 }
256 for (i, slot) in u.iter_mut().enumerate() {
257 let col = self.idx[i] as usize;
258 if col >= v.len() {
259 return Err(SchurDataError::DimensionMismatch);
260 }
261 *slot = (self.val[i] as Number) * v[col];
262 }
263 Ok(())
264 }
265
266 fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError> {
267 if !self.initialized {
268 return Err(SchurDataError::NotInitialized);
269 }
270 if u.len() != self.idx.len() {
271 return Err(SchurDataError::DimensionMismatch);
272 }
273 for slot in v.iter_mut() {
274 *slot = 0.0;
275 }
276 for (i, &row_u) in u.iter().enumerate() {
277 let col = self.idx[i] as usize;
278 if col >= v.len() {
279 return Err(SchurDataError::DimensionMismatch);
280 }
281 v[col] += (self.val[i] as Number) * row_u;
282 }
283 Ok(())
284 }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-variable example: select variables 1 and 3 with sign +1.
    /// `B = [[0 1 0 0], [0 0 0 1]]`.
    #[test]
    fn set_from_flags_round_trip() {
        let mut s = IndexSchurData::new();
        let flags = [0, 1, 0, 1];
        s.set_from_flags(&flags, 1.0).expect("init");
        assert_eq!(s.nrows(), 2);
        assert_eq!(s.col_indices(), &[1, 3]);
        assert_eq!(s.signs(), &[1, 1]);
        assert!(s.is_initialized());
    }

    #[test]
    fn set_from_flags_negative_sign_records_minus_one() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0, 1], -2.5).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
    }

    #[test]
    fn set_from_flags_rejects_double_init() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0], 1.0).expect("first init");
        assert_eq!(
            s.set_from_flags(&[0, 1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
    }

    #[test]
    fn set_from_flags_rejects_zero_sign() {
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 1], 0.0),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn set_from_flags_rejects_non_binary_flag() {
        // Upstream asserts flag ∈ {0, 1}; pounce reports it as
        // `AlreadyInitialized` and leaves the instance uninitialized.
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[0, 2], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
        assert!(!s.is_initialized());
    }

    #[test]
    fn set_from_list_records_each_column_once() {
        let mut s = IndexSchurData::new();
        s.set_from_list(&[2, 0, 4], 1.0).expect("init");
        assert_eq!(s.nrows(), 3);
        assert_eq!(s.col_indices(), &[2, 0, 4]);
        assert_eq!(s.signs(), &[1, 1, 1]);
    }

    #[test]
    fn set_from_list_rejects_zero_sign_and_double_init() {
        let mut s = IndexSchurData::new();
        assert_eq!(s.set_from_list(&[0], 0.0), Err(SchurDataError::ZeroSign));
        assert!(!s.is_initialized());
        s.set_from_list(&[0], -1.0).expect("init");
        assert_eq!(s.signs(), &[-1]);
        assert_eq!(
            s.set_from_list(&[1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
    }

    #[test]
    fn from_parts_validates_signs() {
        // Mixed ±1 OK
        let ok = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).expect("ok");
        assert_eq!(ok.signs(), &[1, -1]);
        // Length mismatch
        assert_eq!(
            IndexSchurData::from_parts(vec![0, 2], vec![1]),
            Err(SchurDataError::DimensionMismatch),
        );
        // Non-±1
        assert_eq!(
            IndexSchurData::from_parts(vec![0], vec![2]),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn multiply_picks_selected_columns_with_signs() {
        // B = [[0 1 0 0], [0 0 0 -1]]
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let v = [10.0, 20.0, 30.0, 40.0];
        let mut u = [0.0; 2];
        s.multiply(&v, &mut u).expect("ok");
        // u[0] = +1·v[1] = 20
        // u[1] = -1·v[3] = -40
        assert_eq!(u, [20.0, -40.0]);
    }

    #[test]
    fn multiply_rejects_mismatched_output_length() {
        let s = IndexSchurData::from_parts(vec![0, 1], vec![1, 1]).expect("ok");
        let mut u = [0.0; 1];
        assert_eq!(
            s.multiply(&[1.0, 2.0], &mut u),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn trans_multiply_scatters_with_signs() {
        // B from previous test; Bᵀ u with u = [3, 5] should produce
        // v = (0, 3, 0, -5).
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let u = [3.0, 5.0];
        let mut v = [0.0; 4];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [0.0, 3.0, 0.0, -5.0]);
    }

    #[test]
    fn trans_multiply_overwrites_caller_buffer() {
        // Existing entries in v are zeroed before the scatter.
        let s = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).expect("ok");
        let u = [1.0, 2.0];
        let mut v = [99.0, 99.0, 99.0, 99.0];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [1.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn trans_multiply_rejects_out_of_range_column() {
        let s = IndexSchurData::from_parts(vec![5], vec![1]).expect("ok");
        let mut v = [0.0; 2];
        assert_eq!(
            s.trans_multiply(&[1.0], &mut v),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn multiply_rejects_uninitialized() {
        let s = IndexSchurData::new();
        let v = [0.0];
        let mut u = [0.0];
        assert_eq!(s.multiply(&v, &mut u), Err(SchurDataError::NotInitialized));
    }

    #[test]
    fn multiplying_row_out_of_range() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(2), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_rejects_negative_index() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(-1), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_returns_single_entry_for_index_schur_data() {
        // Per upstream `SensIndexSchurData.cpp:199-212`, `IndexSchurData`
        // rows always have exactly one non-zero entry — pounce mirrors
        // that contract.
        let s = IndexSchurData::from_parts(vec![5, 7], vec![1, -1]).expect("ok");
        let (idxs, facs) = s.multiplying_row(0).expect("ok");
        assert_eq!(idxs, &[5]);
        assert_eq!(facs, &[1.0]);
        let (idxs, facs) = s.multiplying_row(1).expect("ok");
        assert_eq!(idxs, &[7]);
        assert_eq!(facs, &[-1.0]);
    }
}