Skip to main content

pounce_sensitivity/
schur_data.rs

//! `SchurData` trait surface and the `IndexSchurData` flavor.
//!
//! Direct port of upstream
//! [`SensSchurData.hpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)
//! (interface) and
//! [`SensIndexSchurData.{hpp,cpp}`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)
//! (the index-only specialization).
//!
//! # What `SchurData` represents
//!
//! `SchurData` is the matrix `B` in the augmented system
//!
//! ```text
//! ⎡ K   A ⎤
//! ⎣ B   0 ⎦
//! ```
//!
//! used by the sIPOPT step calculation (Pirnay, López-Negrete & Biegler 2012,
//! §2, eq. 4). `K` is the converged KKT matrix from the original IPM solve;
//! `A` and `B` carry the parameter-perturbation rows. The trait is the
//! minimum surface every backend needs to expose so that the
//! `PCalculator` / `SchurDriver` family can stay agnostic to how the matrix is stored
//! ([`SensSchurData.hpp:17-25`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
//!
//! # `IndexSchurData`
//!
//! Specialization for parameter rows whose only non-zero entries are
//! ±1 ([`SensIndexSchurData.hpp:15-19`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
//! Most parametric / reduced-Hessian use cases fit this shape — the
//! parameter just picks out a subset of primal/dual variables — so the
//! ±1 specialization is what production sIPOPT runs on. Pounce's
//! port stores parallel `Vec<Index>` arrays of column indices and signs,
//! same as upstream's `idx_` / `val_`
//! ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
35
36use pounce_common::types::{Index, Number};
37
38/// Minimum surface for any matrix that lives in the augmented sIPOPT
39/// system's `A` / `B` slots. The numerical drivers in this crate
40/// (`PCalculator`, `SchurDriver`, `SensStepCalc`) consume `SchurData`
41/// objects and never touch the storage shape directly.
42///
43/// Mirrors `Ipopt::SchurData` from
44/// [`SensSchurData.hpp:29-178`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp).
45///
46/// # Lifecycle
47///
48/// A `SchurData` instance starts uninitialized. Exactly one of the
49/// `set_*` methods is called to populate it; subsequent reads via
50/// `nrows`, `multiply`, `trans_multiply`, etc. require an
51/// initialized instance. Upstream enforces this via `Set_Initialized`
52/// asserts in DBG builds
53/// ([`SensIndexSchurData.cpp:59-60`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp));
54/// pounce mirrors that invariant via `Result` returns on the read
55/// surface so a mis-ordered call surfaces as `Err(SchurDataError::NotInitialized)`
56/// instead of a panic.
57pub trait SchurData {
58    /// Number of rows the schur matrix has, i.e. the row count of `B`.
59    /// Upstream `GetNRowsAdded()`
60    /// ([`SensSchurData.hpp:77-80`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
61    fn nrows(&self) -> Index;
62
63    /// `true` if one of the `set_*` methods has been called and the
64    /// rest of the surface is safe to call. Upstream `Is_Initialized()`
65    /// ([`SensSchurData.hpp:82-85`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
66    fn is_initialized(&self) -> bool;
67
68    /// Set rows from a 0/1 flag array of length `dim`. For each `i`
69    /// with `flags[i] == 1`, add a row whose only non-zero column is
70    /// `i` with sign `sign(v)`. The magnitude of `v` is collapsed to
71    /// ±1 — this trait only ever stores signs, mirroring upstream's
72    /// `SetData_Flag(dim, flags, v)`
73    /// ([`SensSchurData.hpp:45-49`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
74    /// [`SensIndexSchurData.cpp:51-78`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
75    ///
76    /// Returns `Err(SchurDataError::AlreadyInitialized)` if the instance
77    /// was already initialized, or `Err(SchurDataError::InvalidFlag)` if any
78    /// `flags[i]` is not 0/1. On either error the instance is left
79    /// unchanged, so a corrected retry is safe.
80    fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError>;
81
82    /// Set rows from a list of column indices. Each `cols[k]` becomes
83    /// row `k` of `B`, with the single non-zero entry at column
84    /// `cols[k]` carrying sign `sign(v)`. Upstream `SetData_List`
85    /// ([`SensSchurData.hpp:64-67`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
86    /// [`SensIndexSchurData.cpp:149-167`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
87    ///
88    /// Returns `Err(_)` if already initialized.
89    fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError>;
90
91    /// Row-`i` access: return the parallel arrays `(indices, factors)`
92    /// such that row `i` of `B` has non-zero column entries
93    /// `factors[j]` at columns `indices[j]`. For `IndexSchurData`
94    /// `indices` has length 1 and `factors[0] == ±1.0`. Upstream
95    /// `GetMultiplyingVectors`
96    /// ([`SensSchurData.hpp:93-104`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
97    /// [`SensIndexSchurData.cpp:199-212`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
98    fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError>;
99
100    /// Apply `u = B v` for a `v` of length `n_full` and pre-sized
101    /// `u` buffer (length must equal `self.nrows()`). Upstream
102    /// `Multiply(IteratesVector, Vector)`
103    /// ([`SensSchurData.hpp:106-110`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
104    /// [`SensIndexSchurData.cpp:214-251`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
105    ///
106    /// In pounce we operate on flat `&[Number]` instead of upstream's
107    /// `IteratesVector` block layout because Phase-A `SchurData` is
108    /// shape-agnostic. Phase-B reconstructs the block layout where
109    /// the algorithm-side needs it.
110    fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError>;
111
112    /// Apply `v = Bᵀ u` for a `u` of length `self.nrows()` and a
113    /// pre-sized `v` buffer (length `n_full`). Upstream
114    /// `TransMultiply`
115    /// ([`SensSchurData.hpp:112-116`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
116    /// [`SensIndexSchurData.cpp:253-307`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
117    fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError>;
118}
119
/// Failure modes returned by [`SchurData`] read/write entry points and by
/// [`IndexSchurData::from_parts`]. Pounce returns these as `Err(_)` where
/// upstream's debug builds would `DBG_ASSERT`.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it
/// composes with `?`, `Box<dyn Error>` and error-reporting crates.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurDataError {
    /// A read method was called before a `set_*` method initialized
    /// the instance. Upstream asserts `Is_Initialized()` in DBG builds
    /// (e.g. [`SensIndexSchurData.cpp:176`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
    NotInitialized,
    /// `set_*` called twice on the same instance. Upstream:
    /// [`SensIndexSchurData.cpp:59-69`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp).
    AlreadyInitialized,
    /// `set_from_flags` was passed a flag value other than 0/1. Upstream
    /// asserts `flags[i] == 0 || flags[i] == 1`
    /// ([`SensIndexSchurData.cpp:51-78`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp));
    /// pounce surfaces it as this distinct `Err` (not `AlreadyInitialized`,
    /// which would mislead — the instance is *not* initialized) and leaves
    /// the instance untouched so a corrected retry is safe.
    InvalidFlag,
    /// A row index was out of range (e.g. `multiplying_row(i)` with
    /// `i < 0` or `i >= nrows()`).
    RowOutOfRange,
    /// A length did not match the expected shape: a caller-supplied
    /// buffer of the wrong length, a column index of `B` that does not
    /// address an element of the input vector, or `idx` / `val` arrays of
    /// different lengths handed to `IndexSchurData::from_parts`.
    DimensionMismatch,
    /// A sign that cannot be represented as ±1: either `v == 0` passed to
    /// a sign-only `set_*` API (upstream asserts `v != 0`,
    /// [`SensIndexSchurData.cpp:61`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)),
    /// or a `val` entry other than exactly `1` / `-1` handed to
    /// `IndexSchurData::from_parts` (this includes magnitudes like `2`,
    /// not only `0`).
    ZeroSign,
}

impl std::fmt::Display for SchurDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::NotInitialized => {
                "SchurData instance is not initialized; call a set_* method first"
            }
            Self::AlreadyInitialized => "SchurData instance is already initialized",
            Self::InvalidFlag => "flag value must be 0 or 1",
            Self::RowOutOfRange => "row index is out of range",
            Self::DimensionMismatch => {
                "dimension mismatch between SchurData and a caller-supplied array"
            }
            Self::ZeroSign => "sign must be nonzero (stored signs must be exactly +1 or -1)",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for SchurDataError {}
148
149/// Specialization for `B` matrices whose non-zero entries are ±1
150/// (the parametric / reduced-Hessian common case). Storage is two
151/// parallel arrays mirroring upstream's `idx_` and `val_`
152/// ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
153#[derive(Debug, Clone, Default, PartialEq, Eq)]
154pub struct IndexSchurData {
155    idx: Vec<Index>,
156    /// Stored as ±1 values, matching upstream's `Index`-typed `val_`.
157    /// Kept as `Index` (not `Number`) so the sign is exact and small.
158    val: Vec<Index>,
159    initialized: bool,
160}
161
162impl IndexSchurData {
163    /// Empty / uninitialized instance. Must be populated via one of
164    /// the `set_from_*` methods before being read.
165    pub fn new() -> Self {
166        Self::default()
167    }
168
169    /// Construct directly from pre-built `(idx, val)` arrays. `val`
170    /// must contain only ±1 entries. Mirrors upstream's two-arg
171    /// constructor
172    /// ([`SensIndexSchurData.cpp:24-36`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
173    pub fn from_parts(idx: Vec<Index>, val: Vec<Index>) -> Result<Self, SchurDataError> {
174        if idx.len() != val.len() {
175            return Err(SchurDataError::DimensionMismatch);
176        }
177        if val.iter().any(|&v| v != 1 && v != -1) {
178            return Err(SchurDataError::ZeroSign);
179        }
180        Ok(Self {
181            idx,
182            val,
183            initialized: true,
184        })
185    }
186
187    /// Column indices the rows refer to (one index per row). Upstream
188    /// `GetColIndices()`
189    /// ([`SensIndexSchurData.hpp:114`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
190    pub fn col_indices(&self) -> &[Index] {
191        &self.idx
192    }
193
194    /// Per-row ±1 sign carried by the single non-zero entry in that
195    /// row. Pounce-specific accessor; upstream exposes the data only
196    /// through the `multiplying_*` / `multiply` APIs.
197    pub fn signs(&self) -> &[Index] {
198        &self.val
199    }
200}
201
202impl SchurData for IndexSchurData {
203    fn nrows(&self) -> Index {
204        self.val.len() as Index
205    }
206
207    fn is_initialized(&self) -> bool {
208        self.initialized
209    }
210
211    fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError> {
212        if self.initialized {
213            return Err(SchurDataError::AlreadyInitialized);
214        }
215        if v == 0.0 {
216            return Err(SchurDataError::ZeroSign);
217        }
218        // Validate the whole flag array BEFORE mutating any state, so an
219        // invalid entry leaves the instance exactly as found. The previous
220        // code pushed rows as it scanned and bailed mid-loop on a bad flag,
221        // leaving `idx`/`val` partially populated with `initialized == false`
222        // — a caller that fixed the flags and retried would then append a
223        // second copy of the leading rows (duplicate rows). Upstream asserts
224        // `flags[i] ∈ {0,1}` (`SensIndexSchurData.cpp:51-78`); we surface the
225        // bad input as `InvalidFlag` rather than the misleading
226        // `AlreadyInitialized`.
227        if flags.iter().any(|&f| f != 0 && f != 1) {
228            return Err(SchurDataError::InvalidFlag);
229        }
230        let w: Index = if v > 0.0 { 1 } else { -1 };
231        for (i, &f) in flags.iter().enumerate() {
232            if f == 1 {
233                self.idx.push(i as Index);
234                self.val.push(w);
235            }
236        }
237        self.initialized = true;
238        Ok(())
239    }
240
241    fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError> {
242        if self.initialized {
243            return Err(SchurDataError::AlreadyInitialized);
244        }
245        if v == 0.0 {
246            return Err(SchurDataError::ZeroSign);
247        }
248        let w: Index = if v > 0.0 { 1 } else { -1 };
249        self.idx.extend_from_slice(cols);
250        self.val.resize(cols.len(), w);
251        self.initialized = true;
252        Ok(())
253    }
254
255    fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError> {
256        if !self.initialized {
257            return Err(SchurDataError::NotInitialized);
258        }
259        let i_us = i as usize;
260        if i_us >= self.idx.len() {
261            return Err(SchurDataError::RowOutOfRange);
262        }
263        Ok((vec![self.idx[i_us]], vec![self.val[i_us] as Number]))
264    }
265
266    fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError> {
267        if !self.initialized {
268            return Err(SchurDataError::NotInitialized);
269        }
270        if u.len() != self.idx.len() {
271            return Err(SchurDataError::DimensionMismatch);
272        }
273        for (i, slot) in u.iter_mut().enumerate() {
274            let col = self.idx[i] as usize;
275            if col >= v.len() {
276                return Err(SchurDataError::DimensionMismatch);
277            }
278            *slot = (self.val[i] as Number) * v[col];
279        }
280        Ok(())
281    }
282
283    fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError> {
284        if !self.initialized {
285            return Err(SchurDataError::NotInitialized);
286        }
287        if u.len() != self.idx.len() {
288            return Err(SchurDataError::DimensionMismatch);
289        }
290        for slot in v.iter_mut() {
291            *slot = 0.0;
292        }
293        for (i, &row_u) in u.iter().enumerate() {
294            let col = self.idx[i] as usize;
295            if col >= v.len() {
296                return Err(SchurDataError::DimensionMismatch);
297            }
298            v[col] += (self.val[i] as Number) * row_u;
299        }
300        Ok(())
301    }
302}
303
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-variable example: select variables 1 and 3 with sign +1.
    /// `B = [[0 1 0 0], [0 0 0 1]]`.
    #[test]
    fn set_from_flags_round_trip() {
        let mut s = IndexSchurData::new();
        let flags = [0, 1, 0, 1];
        s.set_from_flags(&flags, 1.0).expect("init");
        assert_eq!(s.nrows(), 2);
        assert_eq!(s.col_indices(), &[1, 3]);
        assert_eq!(s.signs(), &[1, 1]);
        assert!(s.is_initialized());
    }

    #[test]
    fn new_instance_is_empty_and_uninitialized() {
        let s = IndexSchurData::new();
        assert!(!s.is_initialized());
        assert_eq!(s.nrows(), 0);
    }

    #[test]
    fn set_from_flags_negative_sign_records_minus_one() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0, 1], -2.5).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
    }

    #[test]
    fn set_from_flags_rejects_double_init() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0], 1.0).expect("first init");
        assert_eq!(
            s.set_from_flags(&[0, 1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
    }

    #[test]
    fn set_from_flags_rejects_zero_sign() {
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 1], 0.0),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn set_from_flags_rejects_invalid_flag_with_distinct_variant() {
        // A flag value other than 0/1 is bad *input*, not a double-init —
        // it must surface as `InvalidFlag`, not `AlreadyInitialized`.
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 2, 1], 1.0),
            Err(SchurDataError::InvalidFlag),
        );
    }

    #[test]
    fn set_from_flags_invalid_flag_leaves_instance_pristine_for_retry() {
        // The bad flag sits AFTER a valid `1`, so a non-atomic
        // implementation would have already pushed row 0 before bailing.
        // The instance must be left untouched (uninitialized, empty) so a
        // corrected retry produces exactly the right rows — no duplicates.
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 5], 1.0),
            Err(SchurDataError::InvalidFlag),
        );
        assert!(
            !s.is_initialized(),
            "must stay uninitialized after a failed set"
        );
        assert_eq!(s.nrows(), 0, "no partial rows may linger");
        assert_eq!(s.col_indices(), &[] as &[Index]);

        // Corrected retry: selects vars 0 and 2 only.
        s.set_from_flags(&[1, 0, 1], 1.0).expect("retry init");
        assert_eq!(
            s.col_indices(),
            &[0, 2],
            "retry must not append to leftover state"
        );
        assert_eq!(s.signs(), &[1, 1]);
        assert!(s.is_initialized());
    }

    #[test]
    fn set_from_list_records_each_column_once() {
        let mut s = IndexSchurData::new();
        s.set_from_list(&[2, 0, 4], 1.0).expect("init");
        assert_eq!(s.nrows(), 3);
        assert_eq!(s.col_indices(), &[2, 0, 4]);
        assert_eq!(s.signs(), &[1, 1, 1]);
    }

    #[test]
    fn set_from_list_rejects_zero_sign_and_double_init() {
        let mut s = IndexSchurData::new();
        assert_eq!(s.set_from_list(&[0, 1], 0.0), Err(SchurDataError::ZeroSign));
        assert!(!s.is_initialized(), "a zero sign must not initialize");
        s.set_from_list(&[0, 1], -3.0).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
        assert_eq!(
            s.set_from_list(&[2], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
        assert_eq!(s.col_indices(), &[0, 1], "failed re-init must not mutate");
    }

    #[test]
    fn from_parts_validates_signs() {
        // Mixed ±1 OK
        let ok = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).expect("ok");
        assert_eq!(ok.signs(), &[1, -1]);
        // Length mismatch
        assert_eq!(
            IndexSchurData::from_parts(vec![0, 2], vec![1]),
            Err(SchurDataError::DimensionMismatch),
        );
        // Non-±1 magnitude
        assert_eq!(
            IndexSchurData::from_parts(vec![0], vec![2]),
            Err(SchurDataError::ZeroSign),
        );
        // Literal zero entry
        assert_eq!(
            IndexSchurData::from_parts(vec![0], vec![0]),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn multiply_picks_selected_columns_with_signs() {
        // B = [[0 1 0 0], [0 0 0 -1]]
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let v = [10.0, 20.0, 30.0, 40.0];
        let mut u = [0.0; 2];
        s.multiply(&v, &mut u).expect("ok");
        // u[0] = +1·v[1] = 20
        // u[1] = -1·v[3] = -40
        assert_eq!(u, [20.0, -40.0]);
    }

    #[test]
    fn multiply_rejects_output_length_mismatch() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        let mut u = [0.0; 2];
        assert_eq!(
            s.multiply(&[1.0], &mut u),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn multiply_rejects_column_outside_input() {
        // Column 5 does not exist in a length-3 input vector.
        let s = IndexSchurData::from_parts(vec![5], vec![1]).expect("ok");
        let mut u = [0.0];
        assert_eq!(
            s.multiply(&[1.0, 2.0, 3.0], &mut u),
            Err(SchurDataError::DimensionMismatch),
        );
        let mut v = [0.0; 3];
        assert_eq!(
            s.trans_multiply(&[1.0], &mut v),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn trans_multiply_scatters_with_signs() {
        // B from previous test; Bᵀ u with u = [3, 5] should produce
        // v = (0, 3, 0, -5).
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let u = [3.0, 5.0];
        let mut v = [0.0; 4];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [0.0, 3.0, 0.0, -5.0]);
    }

    #[test]
    fn trans_multiply_overwrites_caller_buffer() {
        // Existing entries in v are zeroed before the scatter.
        let s = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).expect("ok");
        let u = [1.0, 2.0];
        let mut v = [99.0, 99.0, 99.0, 99.0];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [1.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn trans_multiply_rejects_uninitialized_and_bad_input_length() {
        let s = IndexSchurData::new();
        let mut v = [0.0; 2];
        assert_eq!(
            s.trans_multiply(&[], &mut v),
            Err(SchurDataError::NotInitialized),
        );
        let t = IndexSchurData::from_parts(vec![0, 1], vec![1, 1]).expect("ok");
        assert_eq!(
            t.trans_multiply(&[1.0], &mut v),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn multiply_rejects_uninitialized() {
        let s = IndexSchurData::new();
        let v = [0.0];
        let mut u = [0.0];
        assert_eq!(s.multiply(&v, &mut u), Err(SchurDataError::NotInitialized));
    }

    #[test]
    fn multiplying_row_out_of_range() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(2), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_rejects_uninitialized_and_negative_row() {
        let s = IndexSchurData::new();
        assert_eq!(s.multiplying_row(0), Err(SchurDataError::NotInitialized));
        let t = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(t.multiplying_row(-1), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_returns_single_entry_for_index_schur_data() {
        // Per upstream `SensIndexSchurData.cpp:199-212`, `IndexSchurData`
        // rows always have exactly one non-zero entry — pounce mirrors
        // that contract.
        let s = IndexSchurData::from_parts(vec![5, 7], vec![1, -1]).expect("ok");
        let (idxs, facs) = s.multiplying_row(0).expect("ok");
        assert_eq!(idxs, &[5]);
        assert_eq!(facs, &[1.0]);
        let (idxs, facs) = s.multiplying_row(1).expect("ok");
        assert_eq!(idxs, &[7]);
        assert_eq!(facs, &[-1.0]);
    }
}