Skip to main content

pounce_sensitivity/
schur_data.rs

1//! `SchurData` trait surface and the `IndexSchurData` flavor.
2//!
3//! Direct port of upstream
4//! [`SensSchurData.hpp`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)
5//! (interface) and
6//! [`SensIndexSchurData.{hpp,cpp}`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)
7//! (the index-only specialization).
8//!
9//! # What `SchurData` represents
10//!
11//! `SchurData` is the matrix `B` in the augmented system
12//!
13//! ```text
14//! ⎡ K   A ⎤
15//! ⎣ B   0 ⎦
16//! ```
17//!
18//! used by the sIPOPT step calculation (Pirnay, López-Negrete & Biegler 2012,
19//! §2, eq. 4). `K` is the converged KKT matrix from the original IPM solve;
20//! `A` and `B` carry the parameter-perturbation rows. The trait is the
21//! minimum surface every backend needs to expose so the
22//! `PCalculator` / `SchurDriver` family can stay matrix-shape-agnostic
23//! ([`SensSchurData.hpp:17-25`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
24//!
25//! # `IndexSchurData`
26//!
27//! Specialization for parameter rows whose only non-zero entries are
28//! ±1 ([`SensIndexSchurData.hpp:15-19`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
29//! Most parametric / reduced-Hessian use cases fit this shape — the
30//! parameter just picks out a subset of primal/dual variables — so the
31//! ±1 sparsification is what production sIPOPT runs on. Pounce's
32//! port stores parallel `Vec<i32>` arrays of indices and signs, same
33//! as upstream's `idx_` / `val_`
34//! ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
35
36use pounce_common::types::{Index, Number};
37
38/// Minimum surface for any matrix that lives in the augmented sIPOPT
39/// system's `A` / `B` slots. The numerical drivers in this crate
40/// (`PCalculator`, `SchurDriver`, `SensStepCalc`) consume `SchurData`
41/// objects and never touch the storage shape directly.
42///
43/// Mirrors `Ipopt::SchurData` from
44/// [`SensSchurData.hpp:29-178`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp).
45///
46/// # Lifecycle
47///
48/// A `SchurData` instance starts uninitialized. Exactly one of the
49/// `set_*` methods is called to populate it; subsequent reads via
50/// `nrows`, `multiply`, `trans_multiply`, etc. require an
51/// initialized instance. Upstream enforces this via `Set_Initialized`
52/// asserts in DBG builds
53/// ([`SensIndexSchurData.cpp:59-60`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp));
54/// pounce mirrors that invariant via `Result` returns on the read
55/// surface so a mis-ordered call surfaces as `Err(SchurDataError::NotInitialized)`
56/// instead of a panic.
57pub trait SchurData {
58    /// Number of rows the schur matrix has, i.e. the row count of `B`.
59    /// Upstream `GetNRowsAdded()`
60    /// ([`SensSchurData.hpp:77-80`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
61    fn nrows(&self) -> Index;
62
63    /// `true` if one of the `set_*` methods has been called and the
64    /// rest of the surface is safe to call. Upstream `Is_Initialized()`
65    /// ([`SensSchurData.hpp:82-85`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp)).
66    fn is_initialized(&self) -> bool;
67
68    /// Set rows from a 0/1 flag array of length `dim`. For each `i`
69    /// with `flags[i] == 1`, add a row whose only non-zero column is
70    /// `i` with sign `sign(v)`. The magnitude of `v` is collapsed to
71    /// ±1 — this trait only ever stores signs, mirroring upstream's
72    /// `SetData_Flag(dim, flags, v)`
73    /// ([`SensSchurData.hpp:45-49`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
74    /// [`SensIndexSchurData.cpp:51-78`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
75    ///
76    /// Returns `Err(_)` if the instance was already initialized.
77    fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError>;
78
79    /// Set rows from a list of column indices. Each `cols[k]` becomes
80    /// row `k` of `B`, with the single non-zero entry at column
81    /// `cols[k]` carrying sign `sign(v)`. Upstream `SetData_List`
82    /// ([`SensSchurData.hpp:64-67`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
83    /// [`SensIndexSchurData.cpp:149-167`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
84    ///
85    /// Returns `Err(_)` if already initialized.
86    fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError>;
87
88    /// Row-`i` access: return the parallel arrays `(indices, factors)`
89    /// such that row `i` of `B` has non-zero column entries
90    /// `factors[j]` at columns `indices[j]`. For `IndexSchurData`
91    /// `indices` has length 1 and `factors[0] == ±1.0`. Upstream
92    /// `GetMultiplyingVectors`
93    /// ([`SensSchurData.hpp:93-104`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
94    /// [`SensIndexSchurData.cpp:199-212`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
95    fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError>;
96
97    /// Apply `u = B v` for a `v` of length `n_full` and pre-sized
98    /// `u` buffer (length must equal `self.nrows()`). Upstream
99    /// `Multiply(IteratesVector, Vector)`
100    /// ([`SensSchurData.hpp:106-110`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
101    /// [`SensIndexSchurData.cpp:214-251`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
102    ///
103    /// In pounce we operate on flat `&[Number]` instead of upstream's
104    /// `IteratesVector` block layout because Phase-A `SchurData` is
105    /// shape-agnostic. Phase-B reconstructs the block layout where
106    /// the algorithm-side needs it.
107    fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError>;
108
109    /// Apply `v = Bᵀ u` for a `u` of length `self.nrows()` and a
110    /// pre-sized `v` buffer (length `n_full`). Upstream
111    /// `TransMultiply`
112    /// ([`SensSchurData.hpp:112-116`](../../../ref/Ipopt/contrib/sIPOPT/src/SensSchurData.hpp),
113    /// [`SensIndexSchurData.cpp:253-307`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
114    fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError>;
115}
116
/// Failure modes returned by [`SchurData`] read/write entry points.
/// Pounce returns these as `Err(_)` where upstream's debug asserts
/// would `DBG_ASSERT`.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`] so it can be
/// propagated with `?` into `Box<dyn Error>` and reported to users.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SchurDataError {
    /// A read method was called before a `set_*` method initialized
    /// the instance. Upstream asserts `Is_Initialized()` in DBG builds
    /// (e.g. [`SensIndexSchurData.cpp:176`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
    NotInitialized,
    /// `set_*` called twice on the same instance, or `flags` contained
    /// values other than 0/1. Upstream:
    /// [`SensIndexSchurData.cpp:59-69`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp).
    AlreadyInitialized,
    /// A row index was out of range (e.g. `multiplying_row(i)` with
    /// `i >= nrows()`).
    RowOutOfRange,
    /// Caller-supplied buffer length didn't match the expected shape.
    DimensionMismatch,
    /// `v == 0` passed to a sign-only `set_*` API, or a stored sign handed
    /// to `IndexSchurData::from_parts` that is not exactly ±1. Upstream
    /// asserts `v != 0` ([`SensIndexSchurData.cpp:61`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
    ZeroSign,
}

impl std::fmt::Display for SchurDataError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let msg = match self {
            Self::NotInitialized => "SchurData read before a set_* call initialized it",
            Self::AlreadyInitialized => "SchurData already initialized, or a flag was not 0/1",
            Self::RowOutOfRange => "SchurData row index out of range",
            Self::DimensionMismatch => "SchurData buffer length does not match the expected shape",
            Self::ZeroSign => "SchurData sign must be nonzero (stored signs must be +1 or -1)",
        };
        f.write_str(msg)
    }
}

impl std::error::Error for SchurDataError {}
139
140/// Specialization for `B` matrices whose non-zero entries are ±1
141/// (the parametric / reduced-Hessian common case). Storage is two
142/// parallel arrays mirroring upstream's `idx_` and `val_`
143/// ([`SensIndexSchurData.hpp:127-128`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
144#[derive(Debug, Clone, Default, PartialEq, Eq)]
145pub struct IndexSchurData {
146    idx: Vec<Index>,
147    /// Stored as ±1 values, matching upstream's `Index`-typed `val_`.
148    /// Kept as `Index` (not `Number`) so the sign is exact and small.
149    val: Vec<Index>,
150    initialized: bool,
151}
152
153impl IndexSchurData {
154    /// Empty / uninitialized instance. Must be populated via one of
155    /// the `set_from_*` methods before being read.
156    pub fn new() -> Self {
157        Self::default()
158    }
159
160    /// Construct directly from pre-built `(idx, val)` arrays. `val`
161    /// must contain only ±1 entries. Mirrors upstream's two-arg
162    /// constructor
163    /// ([`SensIndexSchurData.cpp:24-36`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.cpp)).
164    pub fn from_parts(idx: Vec<Index>, val: Vec<Index>) -> Result<Self, SchurDataError> {
165        if idx.len() != val.len() {
166            return Err(SchurDataError::DimensionMismatch);
167        }
168        if val.iter().any(|&v| v != 1 && v != -1) {
169            return Err(SchurDataError::ZeroSign);
170        }
171        Ok(Self {
172            idx,
173            val,
174            initialized: true,
175        })
176    }
177
178    /// Column indices the rows refer to (one index per row). Upstream
179    /// `GetColIndices()`
180    /// ([`SensIndexSchurData.hpp:114`](../../../ref/Ipopt/contrib/sIPOPT/src/SensIndexSchurData.hpp)).
181    pub fn col_indices(&self) -> &[Index] {
182        &self.idx
183    }
184
185    /// Per-row ±1 sign carried by the single non-zero entry in that
186    /// row. Pounce-specific accessor; upstream exposes the data only
187    /// through the `multiplying_*` / `multiply` APIs.
188    pub fn signs(&self) -> &[Index] {
189        &self.val
190    }
191}
192
193impl SchurData for IndexSchurData {
194    fn nrows(&self) -> Index {
195        self.val.len() as Index
196    }
197
198    fn is_initialized(&self) -> bool {
199        self.initialized
200    }
201
202    fn set_from_flags(&mut self, flags: &[Index], v: Number) -> Result<(), SchurDataError> {
203        if self.initialized {
204            return Err(SchurDataError::AlreadyInitialized);
205        }
206        if v == 0.0 {
207            return Err(SchurDataError::ZeroSign);
208        }
209        let w: Index = if v > 0.0 { 1 } else { -1 };
210        for (i, &f) in flags.iter().enumerate() {
211            match f {
212                0 => {}
213                1 => {
214                    self.idx.push(i as Index);
215                    self.val.push(w);
216                }
217                _ => return Err(SchurDataError::AlreadyInitialized), // upstream asserts flag ∈ {0,1}
218            }
219        }
220        self.initialized = true;
221        Ok(())
222    }
223
224    fn set_from_list(&mut self, cols: &[Index], v: Number) -> Result<(), SchurDataError> {
225        if self.initialized {
226            return Err(SchurDataError::AlreadyInitialized);
227        }
228        if v == 0.0 {
229            return Err(SchurDataError::ZeroSign);
230        }
231        let w: Index = if v > 0.0 { 1 } else { -1 };
232        self.idx.extend_from_slice(cols);
233        self.val.resize(cols.len(), w);
234        self.initialized = true;
235        Ok(())
236    }
237
238    fn multiplying_row(&self, i: Index) -> Result<(Vec<Index>, Vec<Number>), SchurDataError> {
239        if !self.initialized {
240            return Err(SchurDataError::NotInitialized);
241        }
242        let i_us = i as usize;
243        if i_us >= self.idx.len() {
244            return Err(SchurDataError::RowOutOfRange);
245        }
246        Ok((vec![self.idx[i_us]], vec![self.val[i_us] as Number]))
247    }
248
249    fn multiply(&self, v: &[Number], u: &mut [Number]) -> Result<(), SchurDataError> {
250        if !self.initialized {
251            return Err(SchurDataError::NotInitialized);
252        }
253        if u.len() != self.idx.len() {
254            return Err(SchurDataError::DimensionMismatch);
255        }
256        for (i, slot) in u.iter_mut().enumerate() {
257            let col = self.idx[i] as usize;
258            if col >= v.len() {
259                return Err(SchurDataError::DimensionMismatch);
260            }
261            *slot = (self.val[i] as Number) * v[col];
262        }
263        Ok(())
264    }
265
266    fn trans_multiply(&self, u: &[Number], v: &mut [Number]) -> Result<(), SchurDataError> {
267        if !self.initialized {
268            return Err(SchurDataError::NotInitialized);
269        }
270        if u.len() != self.idx.len() {
271            return Err(SchurDataError::DimensionMismatch);
272        }
273        for slot in v.iter_mut() {
274            *slot = 0.0;
275        }
276        for (i, &row_u) in u.iter().enumerate() {
277            let col = self.idx[i] as usize;
278            if col >= v.len() {
279                return Err(SchurDataError::DimensionMismatch);
280            }
281            v[col] += (self.val[i] as Number) * row_u;
282        }
283        Ok(())
284    }
285}
286
#[cfg(test)]
mod tests {
    use super::*;

    /// Four-column example: the flags select columns 1 and 3 with sign +1,
    /// so `B = [[0 1 0 0], [0 0 0 1]]`.
    #[test]
    fn set_from_flags_round_trip() {
        let mut s = IndexSchurData::new();
        let flags = [0, 1, 0, 1];
        s.set_from_flags(&flags, 1.0).expect("init");
        assert_eq!(s.nrows(), 2);
        assert_eq!(s.col_indices(), &[1, 3]);
        assert_eq!(s.signs(), &[1, 1]);
        assert!(s.is_initialized());
    }

    #[test]
    fn set_from_flags_negative_sign_records_minus_one() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0, 1], -2.5).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
    }

    #[test]
    fn set_from_flags_rejects_double_init() {
        let mut s = IndexSchurData::new();
        s.set_from_flags(&[1, 0], 1.0).expect("first init");
        assert_eq!(
            s.set_from_flags(&[0, 1], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
    }

    #[test]
    fn set_from_flags_rejects_zero_sign() {
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[1, 0, 1], 0.0),
            Err(SchurDataError::ZeroSign),
        );
    }

    /// Flag values must be 0 or 1. Anything else is reported through
    /// `AlreadyInitialized` (there is no dedicated variant) and leaves the
    /// instance uninitialized.
    #[test]
    fn set_from_flags_rejects_non_binary_flag() {
        let mut s = IndexSchurData::new();
        assert_eq!(
            s.set_from_flags(&[0, 2], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
        assert!(!s.is_initialized());
    }

    #[test]
    fn set_from_list_records_each_column_once() {
        let mut s = IndexSchurData::new();
        s.set_from_list(&[2, 0, 4], 1.0).expect("init");
        assert_eq!(s.nrows(), 3);
        assert_eq!(s.col_indices(), &[2, 0, 4]);
        assert_eq!(s.signs(), &[1, 1, 1]);
    }

    #[test]
    fn set_from_list_rejects_zero_sign_then_double_init() {
        let mut s = IndexSchurData::new();
        assert_eq!(s.set_from_list(&[0], 0.0), Err(SchurDataError::ZeroSign));
        assert!(!s.is_initialized());
        s.set_from_list(&[1, 2], -0.5).expect("init");
        assert_eq!(s.signs(), &[-1, -1]);
        assert_eq!(
            s.set_from_list(&[0], 1.0),
            Err(SchurDataError::AlreadyInitialized),
        );
    }

    #[test]
    fn from_parts_validates_signs() {
        // Mixed ±1 OK
        let ok = IndexSchurData::from_parts(vec![0, 2], vec![1, -1]).expect("ok");
        assert_eq!(ok.signs(), &[1, -1]);
        // Length mismatch
        assert_eq!(
            IndexSchurData::from_parts(vec![0, 2], vec![1]),
            Err(SchurDataError::DimensionMismatch),
        );
        // Non-±1
        assert_eq!(
            IndexSchurData::from_parts(vec![0], vec![2]),
            Err(SchurDataError::ZeroSign),
        );
    }

    #[test]
    fn multiply_picks_selected_columns_with_signs() {
        // B = [[0 1 0 0], [0 0 0 -1]]
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let v = [10.0, 20.0, 30.0, 40.0];
        let mut u = [0.0; 2];
        s.multiply(&v, &mut u).expect("ok");
        // u[0] = +1·v[1] = 20
        // u[1] = -1·v[3] = -40
        assert_eq!(u, [20.0, -40.0]);
    }

    #[test]
    fn multiply_rejects_shape_mismatches() {
        let s = IndexSchurData::from_parts(vec![3], vec![1]).expect("ok");
        // The output buffer must have exactly `nrows()` entries.
        let mut too_long = [0.0; 2];
        assert_eq!(
            s.multiply(&[0.0; 4], &mut too_long),
            Err(SchurDataError::DimensionMismatch),
        );
        // Every stored column must index into `v`.
        let mut u = [0.0; 1];
        assert_eq!(
            s.multiply(&[0.0; 2], &mut u),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn trans_multiply_scatters_with_signs() {
        // B from previous test; Bᵀ u with u = [3, 5] should produce
        // v = (0, 3, 0, -5).
        let s = IndexSchurData::from_parts(vec![1, 3], vec![1, -1]).expect("ok");
        let u = [3.0, 5.0];
        let mut v = [0.0; 4];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [0.0, 3.0, 0.0, -5.0]);
    }

    #[test]
    fn trans_multiply_accumulates_repeated_columns() {
        // Two rows hit column 1 with opposite signs: 2·(+1) + 5·(-1) = -3.
        let s = IndexSchurData::from_parts(vec![1, 1], vec![1, -1]).expect("ok");
        let mut v = [0.0; 3];
        s.trans_multiply(&[2.0, 5.0], &mut v).expect("ok");
        assert_eq!(v, [0.0, -3.0, 0.0]);
    }

    #[test]
    fn trans_multiply_rejects_wrong_input_length() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        let mut v = [0.0; 2];
        assert_eq!(
            s.trans_multiply(&[1.0, 2.0], &mut v),
            Err(SchurDataError::DimensionMismatch),
        );
    }

    #[test]
    fn trans_multiply_overwrites_caller_buffer() {
        // Existing entries in v are zeroed before the scatter.
        let s = IndexSchurData::from_parts(vec![0, 2], vec![1, 1]).expect("ok");
        let u = [1.0, 2.0];
        let mut v = [99.0, 99.0, 99.0, 99.0];
        s.trans_multiply(&u, &mut v).expect("ok");
        assert_eq!(v, [1.0, 0.0, 2.0, 0.0]);
    }

    #[test]
    fn multiply_rejects_uninitialized() {
        let s = IndexSchurData::new();
        let v = [0.0];
        let mut u = [0.0];
        assert_eq!(s.multiply(&v, &mut u), Err(SchurDataError::NotInitialized));
    }

    #[test]
    fn multiplying_row_out_of_range() {
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(2), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_rejects_negative_row_and_uninitialized() {
        assert_eq!(
            IndexSchurData::new().multiplying_row(0),
            Err(SchurDataError::NotInitialized),
        );
        let s = IndexSchurData::from_parts(vec![0], vec![1]).expect("ok");
        assert_eq!(s.multiplying_row(-1), Err(SchurDataError::RowOutOfRange));
    }

    #[test]
    fn multiplying_row_returns_single_entry_for_index_schur_data() {
        // Per upstream `SensIndexSchurData.cpp:199-212`, `IndexSchurData`
        // rows always have exactly one non-zero entry — pounce mirrors
        // that contract.
        let s = IndexSchurData::from_parts(vec![5, 7], vec![1, -1]).expect("ok");
        let (idxs, facs) = s.multiplying_row(0).expect("ok");
        assert_eq!(idxs, &[5]);
        assert_eq!(facs, &[1.0]);
        let (idxs, facs) = s.multiplying_row(1).expect("ok");
        assert_eq!(idxs, &[7]);
        assert_eq!(facs, &[-1.0]);
    }
}