Skip to main content

pounce_algorithm/kkt/
std_aug_system_solver.rs

1//! Standard augmented-system solver — port of
2//! `Algorithm/IpStdAugSystemSolver.{hpp,cpp}`.
3//!
4//! Flattens the four-block KKT matrix into a single lower-triangular
5//! 1-based triplet and hands it to a [`pounce_linsol::TSymLinearSolver`].
6//! On the first call the structure is computed (and the linsol's
7//! `initialize_structure` is invoked); subsequent calls only refill the
8//! values array and call `multi_solve`. Matches the cache/skip logic in
9//! upstream `IpStdAugSystemSolver::CreateAugmentedSpace` and
10//! `CreateAugmentedSystem`.
11//!
12//! Sign convention follows upstream:
13//!
14//! ```text
15//!   (1,1) = w_factor·W + diag(D_x + δ_x)
16//!   (2,2) = diag(D_s + δ_s)
17//!   (3,1) = J_c
18//!   (3,3) = -diag(D_c + δ_c)
19//!   (4,1) = J_d
20//!   (4,2) = -I
21//!   (4,4) = -diag(D_d + δ_d)
22//! ```
23//!
//! Phase-6 first cut: assumes `W` is a [`SymTMatrix`] and `J_c`/`J_d` are
//! [`GenTMatrix`] — the only concrete matrix types `OrigIpoptNLP` produces.
//! The `D_*` weights and the rhs/solution vectors may be [`DenseVector`]s or
//! [`CompoundVector`]s of `DenseVector`s (the restoration-phase layout).
//! CompoundMatrix flattening (used by L-BFGS in Phase 8) is deferred.
28
29use crate::kkt::aug_system_solver::{AugSysCoeffs, AugSysRhs, AugSysSol, AugSystemSolver};
30use pounce_common::diagnostics::{DiagCategory, DiagnosticsState};
31use pounce_common::timing::TimingStatistics;
32use pounce_common::types::{Index, Number};
33use pounce_linalg::compound_vector::CompoundVector;
34use pounce_linalg::dense_vector::DenseVector;
35use pounce_linalg::triplet::{GenTMatrix, SymTMatrix};
36use pounce_linalg::Vector;
37use pounce_linsol::{ESymSolverStatus, SymLinearSolver, TSymLinearSolver};
38use std::ops::Range;
39use std::rc::Rc;
40
41/// Standard augmented-system solver.
42pub struct StdAugSystemSolver {
43    linsol: TSymLinearSolver,
44
45    /// `true` once the triplet structure has been pinned.
46    initialized: bool,
47    n_x: Index,
48    n_s: Index,
49    n_c: Index,
50    n_d: Index,
51    /// Total dim = `n_x + n_s + n_c + n_d`.
52    dim: Index,
53
54    /// 1-based row indices, length = total triplet nnz.
55    irn: Vec<Index>,
56    /// 1-based col indices.
57    jcn: Vec<Index>,
58    /// Working values array reused across calls.
59    vals: Vec<Number>,
60
61    // Per-block ranges into `vals` / `irn` / `jcn`.
62    w_range: Range<usize>,
63    dx_range: Range<usize>,
64    ds_range: Range<usize>,
65    jc_range: Range<usize>,
66    dc_range: Range<usize>,
67    jd_range: Range<usize>,
68    minus_i_range: Range<usize>,
69    dd_range: Range<usize>,
70
71    last_neg_evals: Index,
72    last_status: Option<ESymSolverStatus>,
73
74    /// `true` once a successful `solve()` has been completed since the
75    /// last reinitialisation or `increase_quality`. Required precondition
76    /// for `resolve()` (back-substitution against the cached factor).
77    have_factor: bool,
78
79    /// Shared per-solve timing accumulator. `None` until the
80    /// algorithm installs it via [`AugSystemSolver::set_timing_stats`];
81    /// when `None`, both `solve` and `resolve` skip the timing bumps.
82    timing: Option<Rc<TimingStatistics>>,
83
84    /// Shared per-solve diagnostics state. `None` unless the
85    /// application requested KKT dumps via the CLI's `--dump` flag.
86    /// When set, every successful `solve()` may emit a JSONL record
87    /// to `<dump_dir>/iter_NNN/kkt_solve_MMM.jsonl`, gated by the
88    /// configured iter-spec for [`DiagCategory::Kkt`].
89    diagnostics: Option<Rc<DiagnosticsState>>,
90}
91
92impl std::fmt::Debug for StdAugSystemSolver {
93    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
94        f.debug_struct("StdAugSystemSolver")
95            .field("dim", &self.dim)
96            .field("nnz", &self.vals.len())
97            .field("initialized", &self.initialized)
98            .field("last_neg_evals", &self.last_neg_evals)
99            .field("last_status", &self.last_status)
100            .finish_non_exhaustive()
101    }
102}
103
104impl StdAugSystemSolver {
105    /// Build a solver around a configured [`TSymLinearSolver`].
106    pub fn new(linsol: TSymLinearSolver) -> Self {
107        Self {
108            linsol,
109            initialized: false,
110            n_x: 0,
111            n_s: 0,
112            n_c: 0,
113            n_d: 0,
114            dim: 0,
115            irn: Vec::new(),
116            jcn: Vec::new(),
117            vals: Vec::new(),
118            w_range: 0..0,
119            dx_range: 0..0,
120            ds_range: 0..0,
121            jc_range: 0..0,
122            dc_range: 0..0,
123            jd_range: 0..0,
124            minus_i_range: 0..0,
125            dd_range: 0..0,
126            last_neg_evals: 0,
127            last_status: None,
128            have_factor: false,
129            timing: None,
130            diagnostics: None,
131        }
132    }
133
134    fn build_structure(&mut self, coeffs: &AugSysCoeffs<'_>) -> ESymSolverStatus {
135        let n_x = coeffs.j_c.n_cols();
136        let n_c = coeffs.j_c.n_rows();
137        let n_d = coeffs.j_d.n_rows();
138        debug_assert_eq!(coeffs.j_d.n_cols(), n_x);
139        let n_s = n_d;
140
141        let w_nnz = match coeffs.w {
142            None => 0_usize,
143            Some(w) => sym_t_downcast(w).nonzeros() as usize,
144        };
145        let jc_nnz = gen_t_downcast(coeffs.j_c).nonzeros() as usize;
146        let jd_nnz = gen_t_downcast(coeffs.j_d).nonzeros() as usize;
147
148        let total = w_nnz
149            + (n_x as usize) // dx diagonal
150            + (n_s as usize) // ds diagonal
151            + jc_nnz
152            + (n_c as usize) // dc diagonal (negative)
153            + jd_nnz
154            + (n_s as usize) // -I block
155            + (n_d as usize); // dd diagonal (negative)
156
157        self.irn = Vec::with_capacity(total);
158        self.jcn = Vec::with_capacity(total);
159        self.vals = vec![0.0; total];
160
161        // ---- (1,1) block: W ----
162        let w_start = self.irn.len();
163        if let Some(w) = coeffs.w {
164            let w = sym_t_downcast(w);
165            self.irn.extend_from_slice(w.irows());
166            self.jcn.extend_from_slice(w.jcols());
167        }
168        self.w_range = w_start..self.irn.len();
169
170        // ---- (1,1) diagonal: D_x + δ_x ----
171        let dx_start = self.irn.len();
172        for i in 0..n_x {
173            self.irn.push(i + 1);
174            self.jcn.push(i + 1);
175        }
176        self.dx_range = dx_start..self.irn.len();
177
178        // ---- (2,2) diagonal: D_s + δ_s ----
179        let ds_start = self.irn.len();
180        for i in 0..n_s {
181            let r = n_x + i + 1;
182            self.irn.push(r);
183            self.jcn.push(r);
184        }
185        self.ds_range = ds_start..self.irn.len();
186
187        // ---- (3,1) block: J_c ----
188        let jc_start = self.irn.len();
189        let j_c = gen_t_downcast(coeffs.j_c);
190        let row_off_c = n_x + n_s;
191        for (&i, &j) in j_c.irows().iter().zip(j_c.jcols().iter()) {
192            // Upstream rows/cols are 1-based already; remap row to the
193            // (3,_) compound block.
194            self.irn.push(row_off_c + i);
195            self.jcn.push(j);
196        }
197        self.jc_range = jc_start..self.irn.len();
198
199        // ---- (3,3) diagonal: -(D_c + δ_c) ----
200        let dc_start = self.irn.len();
201        for i in 0..n_c {
202            let r = n_x + n_s + i + 1;
203            self.irn.push(r);
204            self.jcn.push(r);
205        }
206        self.dc_range = dc_start..self.irn.len();
207
208        // ---- (4,1) block: J_d ----
209        let jd_start = self.irn.len();
210        let j_d = gen_t_downcast(coeffs.j_d);
211        let row_off_d = n_x + n_s + n_c;
212        for (&i, &j) in j_d.irows().iter().zip(j_d.jcols().iter()) {
213            self.irn.push(row_off_d + i);
214            self.jcn.push(j);
215        }
216        self.jd_range = jd_start..self.irn.len();
217
218        // ---- (4,2) block: -I ----
219        let mi_start = self.irn.len();
220        for i in 0..n_s {
221            self.irn.push(n_x + n_s + n_c + i + 1);
222            self.jcn.push(n_x + i + 1);
223        }
224        self.minus_i_range = mi_start..self.irn.len();
225
226        // ---- (4,4) diagonal: -(D_d + δ_d) ----
227        let dd_start = self.irn.len();
228        for i in 0..n_d {
229            let r = n_x + n_s + n_c + i + 1;
230            self.irn.push(r);
231            self.jcn.push(r);
232        }
233        self.dd_range = dd_start..self.irn.len();
234
235        debug_assert_eq!(self.irn.len(), total);
236        debug_assert_eq!(self.jcn.len(), total);
237
238        self.n_x = n_x;
239        self.n_s = n_s;
240        self.n_c = n_c;
241        self.n_d = n_d;
242        self.dim = n_x + n_s + n_c + n_d;
243
244        let status = self
245            .linsol
246            .initialize_structure(self.dim, &self.irn, &self.jcn);
247        if status == ESymSolverStatus::Success {
248            self.initialized = true;
249        }
250        status
251    }
252
253    fn refill_values(&mut self, coeffs: &AugSysCoeffs<'_>) {
254        // (1,1) W
255        if !self.w_range.is_empty() {
256            let Some(w_dyn) = coeffs.w else {
257                unreachable!("structure pinned with W; W cannot be None now")
258            };
259            let w = sym_t_downcast(w_dyn);
260            let dst = &mut self.vals[self.w_range.clone()];
261            for (d, &v) in dst.iter_mut().zip(w.values().iter()) {
262                *d = coeffs.w_factor * v;
263            }
264        }
265        // (1,1) diag: D_x + δ_x
266        fill_diag(
267            &mut self.vals[self.dx_range.clone()],
268            coeffs.d_x,
269            coeffs.delta_x,
270            1.0,
271        );
272        // (2,2) diag: D_s + δ_s
273        fill_diag(
274            &mut self.vals[self.ds_range.clone()],
275            coeffs.d_s,
276            coeffs.delta_s,
277            1.0,
278        );
279        // (3,1) J_c
280        {
281            let j_c = gen_t_downcast(coeffs.j_c);
282            self.vals[self.jc_range.clone()].copy_from_slice(j_c.values());
283        }
284        // (3,3) diag: -(D_c + δ_c)
285        fill_diag(
286            &mut self.vals[self.dc_range.clone()],
287            coeffs.d_c,
288            coeffs.delta_c,
289            -1.0,
290        );
291        // (4,1) J_d
292        {
293            let j_d = gen_t_downcast(coeffs.j_d);
294            self.vals[self.jd_range.clone()].copy_from_slice(j_d.values());
295        }
296        // (4,2) -I
297        for v in &mut self.vals[self.minus_i_range.clone()] {
298            *v = -1.0;
299        }
300        // (4,4) diag: -(D_d + δ_d)
301        fill_diag(
302            &mut self.vals[self.dd_range.clone()],
303            coeffs.d_d,
304            coeffs.delta_d,
305            -1.0,
306        );
307    }
308
309    fn pack_rhs(&self, rhs: &AugSysRhs<'_>, packed: &mut [Number]) {
310        let n_x = self.n_x as usize;
311        let n_s = self.n_s as usize;
312        let n_c = self.n_c as usize;
313        let n_d = self.n_d as usize;
314        copy_vec(rhs.rhs_x, &mut packed[..n_x]);
315        copy_vec(rhs.rhs_s, &mut packed[n_x..n_x + n_s]);
316        copy_vec(rhs.rhs_c, &mut packed[n_x + n_s..n_x + n_s + n_c]);
317        copy_vec(
318            rhs.rhs_d,
319            &mut packed[n_x + n_s + n_c..n_x + n_s + n_c + n_d],
320        );
321    }
322
323    fn unpack_sol(&self, packed: &[Number], sol: &mut AugSysSol<'_>) {
324        let n_x = self.n_x as usize;
325        let n_s = self.n_s as usize;
326        let n_c = self.n_c as usize;
327        let n_d = self.n_d as usize;
328        write_vec(sol.sol_x, &packed[..n_x]);
329        write_vec(sol.sol_s, &packed[n_x..n_x + n_s]);
330        write_vec(sol.sol_c, &packed[n_x + n_s..n_x + n_s + n_c]);
331        write_vec(sol.sol_d, &packed[n_x + n_s + n_c..n_x + n_s + n_c + n_d]);
332    }
333}
334
335impl AugSystemSolver for StdAugSystemSolver {
336    fn provides_inertia(&self) -> bool {
337        self.linsol.provides_inertia()
338    }
339
340    fn number_of_neg_evals(&self) -> Index {
341        self.last_neg_evals
342    }
343
344    fn increase_quality(&mut self) -> bool {
345        // Quality bump → pivtol changed → next solve must refactor.
346        // `resolve` would silently hand back stale numbers; force the
347        // full path by invalidating the cached-factor flag here.
348        self.have_factor = false;
349        self.linsol.increase_quality()
350    }
351
352    fn last_solve_status(&self) -> ESymSolverStatus {
353        self.last_status.unwrap_or(ESymSolverStatus::FatalError)
354    }
355
356    fn solve(
357        &mut self,
358        coeffs: &AugSysCoeffs<'_>,
359        rhs: &AugSysRhs<'_>,
360        sol: &mut AugSysSol<'_>,
361        check_neg_evals: bool,
362        num_neg_evals: Index,
363    ) -> ESymSolverStatus {
364        if !self.initialized {
365            let s = self.build_structure(coeffs);
366            if s != ESymSolverStatus::Success {
367                self.last_status = Some(s);
368                return s;
369            }
370        }
371        self.refill_values(coeffs);
372
373        let mut packed = vec![0.0; self.dim as usize];
374        self.pack_rhs(rhs, &mut packed);
375
376        let dump_rhs = packed.clone();
377
378        // Attributes the whole factor+back-solve to
379        // `linear_system_factorization` (mirrors upstream
380        // `IpStdAugSystemSolver.cpp:155`).
381        let _factor_guard = self
382            .timing
383            .as_deref()
384            .map(|t| t.linear_system_factorization.guard());
385        let status = self.linsol.multi_solve(
386            &self.vals,
387            true,
388            1,
389            &mut packed,
390            check_neg_evals,
391            num_neg_evals,
392        );
393        drop(_factor_guard);
394        self.last_status = Some(status);
395        if status == ESymSolverStatus::Success {
396            if self.linsol.provides_inertia() {
397                self.last_neg_evals = self.linsol.number_of_neg_evals();
398            }
399            self.unpack_sol(&packed, sol);
400            self.have_factor = true;
401        }
402
403        // Diagnostic dump: structured `--dump kkt:...` surface, then
404        // the legacy `POUNCE_DUMP_KKT=<path>` env-var fallback. The
405        // two paths share `write_kkt_record` so the JSON line is bit-
406        // identical regardless of how the dump was requested.
407        if let Some(diag) = self.diagnostics.clone() {
408            if diag.want(DiagCategory::Kkt) {
409                let solve_idx = diag.next_solve_index();
410                let filename = format!("kkt_solve_{solve_idx:03}.jsonl");
411                if let Some(mut w) = diag.open_writer(&filename) {
412                    let _ = write_kkt_record(
413                        &mut w,
414                        self.dim,
415                        &self.irn,
416                        &self.jcn,
417                        &self.vals,
418                        &dump_rhs,
419                        &packed,
420                        check_neg_evals,
421                        num_neg_evals,
422                        status,
423                        self.last_neg_evals,
424                    );
425                }
426            }
427        }
428        if let Ok(path) = std::env::var("POUNCE_DUMP_KKT") {
429            use std::sync::atomic::{AtomicBool, Ordering};
430            static WARNED: AtomicBool = AtomicBool::new(false);
431            if !WARNED.swap(true, Ordering::SeqCst) {
432                eprintln!(
433                    "warning: POUNCE_DUMP_KKT is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
434                );
435            }
436            dump_kkt(
437                &path,
438                self.dim,
439                &self.irn,
440                &self.jcn,
441                &self.vals,
442                &dump_rhs,
443                &packed,
444                check_neg_evals,
445                num_neg_evals,
446                status,
447                self.last_neg_evals,
448            );
449        }
450
451        status
452    }
453
454    fn resolve(
455        &mut self,
456        coeffs: &AugSysCoeffs<'_>,
457        rhs: &AugSysRhs<'_>,
458        sol: &mut AugSysSol<'_>,
459    ) -> ESymSolverStatus {
460        // Contract: caller has invoked `solve` with byte-identical
461        // coefficients since the last `increase_quality`. We trust
462        // them and reuse the cached factor. If `have_factor` is false
463        // (cold start, or quality was bumped), fall through to a
464        // full solve so correctness is preserved even when the call
465        // site misjudges the cache state.
466        if !self.have_factor {
467            return self.solve(coeffs, rhs, sol, false, 0);
468        }
469
470        let mut packed = vec![0.0; self.dim as usize];
471        self.pack_rhs(rhs, &mut packed);
472
473        // Back-substitution against the cached factor; mirrors upstream
474        // `IpStdAugSystemSolver.cpp` `linear_system_back_solve` task.
475        let _back_guard = self
476            .timing
477            .as_deref()
478            .map(|t| t.linear_system_back_solve.guard());
479        let status = self
480            .linsol
481            .multi_solve(&self.vals, false, 1, &mut packed, false, 0);
482        drop(_back_guard);
483        self.last_status = Some(status);
484        if status == ESymSolverStatus::Success {
485            self.unpack_sol(&packed, sol);
486        }
487        status
488    }
489
490    fn set_diagnostics(&mut self, diag: Rc<DiagnosticsState>) {
491        self.diagnostics = Some(diag);
492    }
493
494    fn set_timing_stats(&mut self, timing: Rc<TimingStatistics>) {
495        self.timing = Some(timing);
496    }
497}
498
499// ---------------- helpers ----------------
500
501#[allow(clippy::too_many_arguments)]
502/// Serialize one KKT solve as a single JSONL record. Shared by the
503/// `--dump kkt:...` path (one file per solve under `iter_NNN/`) and
504/// the legacy `POUNCE_DUMP_KKT` path (one append-mode file across
505/// the whole run).
506fn write_kkt_record(
507    w: &mut dyn std::io::Write,
508    dim: Index,
509    irn: &[Index],
510    jcn: &[Index],
511    vals: &[Number],
512    rhs: &[Number],
513    sol: &[Number],
514    check_neg_evals: bool,
515    num_neg_evals: Index,
516    status: ESymSolverStatus,
517    last_neg_evals: Index,
518) -> std::io::Result<()> {
519    use std::fmt::Write as _;
520
521    let mut line = String::with_capacity(64 * vals.len());
522    line.push('{');
523    let _ = write!(line, "\"n\":{dim},");
524    let _ = write!(line, "\"check_neg_evals\":{check_neg_evals},");
525    let _ = write!(line, "\"num_neg_evals_expected\":{num_neg_evals},");
526    let _ = write!(line, "\"num_neg_evals_actual\":{last_neg_evals},");
527    let _ = write!(line, "\"status\":\"{status:?}\",");
528
529    line.push_str("\"irn\":[");
530    for (i, v) in irn.iter().enumerate() {
531        if i > 0 {
532            line.push(',');
533        }
534        let _ = write!(line, "{v}");
535    }
536    line.push_str("],\"jcn\":[");
537    for (i, v) in jcn.iter().enumerate() {
538        if i > 0 {
539            line.push(',');
540        }
541        let _ = write!(line, "{v}");
542    }
543    line.push_str("],\"vals\":[");
544    for (i, v) in vals.iter().enumerate() {
545        if i > 0 {
546            line.push(',');
547        }
548        let _ = write!(line, "{v:.17e}");
549    }
550    line.push_str("],\"rhs\":[");
551    for (i, v) in rhs.iter().enumerate() {
552        if i > 0 {
553            line.push(',');
554        }
555        let _ = write!(line, "{v:.17e}");
556    }
557    line.push_str("],\"sol\":[");
558    for (i, v) in sol.iter().enumerate() {
559        if i > 0 {
560            line.push(',');
561        }
562        let _ = write!(line, "{v:.17e}");
563    }
564    line.push_str("]}\n");
565
566    w.write_all(line.as_bytes())
567}
568
569fn dump_kkt(
570    path: &str,
571    dim: Index,
572    irn: &[Index],
573    jcn: &[Index],
574    vals: &[Number],
575    rhs: &[Number],
576    sol: &[Number],
577    check_neg_evals: bool,
578    num_neg_evals: Index,
579    status: ESymSolverStatus,
580    last_neg_evals: Index,
581) {
582    if let Ok(mut f) = std::fs::OpenOptions::new()
583        .create(true)
584        .append(true)
585        .open(path)
586    {
587        let _ = write_kkt_record(
588            &mut f,
589            dim,
590            irn,
591            jcn,
592            vals,
593            rhs,
594            sol,
595            check_neg_evals,
596            num_neg_evals,
597            status,
598            last_neg_evals,
599        );
600    }
601}
602
603fn sym_t_downcast(m: &dyn pounce_linalg::SymMatrix) -> &SymTMatrix {
604    let Some(t) = m.as_any().downcast_ref::<SymTMatrix>() else {
605        unreachable!("StdAugSystemSolver: W must be a SymTMatrix in v1.0")
606    };
607    t
608}
609
610fn gen_t_downcast(m: &dyn pounce_linalg::Matrix) -> &GenTMatrix {
611    let Some(t) = m.as_any().downcast_ref::<GenTMatrix>() else {
612        unreachable!("StdAugSystemSolver: J_c / J_d must be GenTMatrix in v1.0")
613    };
614    t
615}
616
617/// Read a vector that is either a [`DenseVector`] or a
618/// [`CompoundVector`] of [`DenseVector`]s into a contiguous owned
619/// `Vec<Number>`. The resto-side IPM hands us 5-block compound x /
620/// D_x; v1.0 originals always arrive as `DenseVector`. Panics on any
621/// other layout.
622fn flat_read(v: &dyn Vector) -> Vec<Number> {
623    if let Some(dv) = v.as_any().downcast_ref::<DenseVector>() {
624        return dv.expanded_values();
625    }
626    if let Some(cv) = v.as_any().downcast_ref::<CompoundVector>() {
627        let mut out = Vec::with_capacity(cv.dim() as usize);
628        for k in 0..cv.n_comps() {
629            let blk = cv.comp(k);
630            let dblk = blk
631                .as_any()
632                .downcast_ref::<DenseVector>()
633                .expect("StdAugSystemSolver: CompoundVector blocks must be DenseVectors");
634            out.extend_from_slice(&dblk.expanded_values());
635        }
636        return out;
637    }
638    unreachable!("StdAugSystemSolver: D_*/rhs/sol must be DenseVector or CompoundVector of DenseVectors in v1.0")
639}
640
641/// Inverse of [`flat_read`].
642fn flat_write(dst: &mut dyn Vector, src: &[Number]) {
643    if let Some(dv) = dst.as_any_mut().downcast_mut::<DenseVector>() {
644        dv.set_values(src);
645        return;
646    }
647    if let Some(cv) = dst.as_any_mut().downcast_mut::<CompoundVector>() {
648        let mut off = 0usize;
649        for k in 0..cv.n_comps() {
650            let blk = cv.comp_mut(k);
651            let dim = blk.dim() as usize;
652            let dblk = blk
653                .as_any_mut()
654                .downcast_mut::<DenseVector>()
655                .expect("StdAugSystemSolver: CompoundVector blocks must be DenseVectors");
656            dblk.set_values(&src[off..off + dim]);
657            off += dim;
658        }
659        return;
660    }
661    unreachable!(
662        "StdAugSystemSolver: sol must be DenseVector or CompoundVector of DenseVectors in v1.0"
663    )
664}
665
666/// Write `sign · (D[i] + delta)` into each slot. `D = None` means
667/// the diagonal weight is zero, leaving just `sign · delta`.
668fn fill_diag(dst: &mut [Number], d: Option<&dyn Vector>, delta: Number, sign: Number) {
669    match d {
670        None => {
671            for v in dst.iter_mut() {
672                *v = sign * delta;
673            }
674        }
675        Some(d) => {
676            let xs = flat_read(d);
677            debug_assert_eq!(xs.len(), dst.len());
678            for (out, &x) in dst.iter_mut().zip(xs.iter()) {
679                *out = sign * (x + delta);
680            }
681        }
682    }
683}
684
685fn copy_vec(src: &dyn Vector, dst: &mut [Number]) {
686    let xs = flat_read(src);
687    debug_assert_eq!(xs.len(), dst.len());
688    dst.copy_from_slice(&xs);
689}
690
691fn write_vec(dst: &mut dyn Vector, src: &[Number]) {
692    flat_write(dst, src);
693}
694
#[cfg(test)]
mod tests {
    use super::*;
    use pounce_common::types::{Index, Number};
    use pounce_linalg::dense_vector::DenseVectorSpace;
    use pounce_linalg::triplet::{GenTMatrixSpace, SymTMatrixSpace};
    use pounce_linsol::sparse_sym_iface::SparseSymLinearSolverInterface;
    use pounce_linsol::EMatrixFormat;

    /// Mock backend: densifies the symmetric triplet and solves it by
    /// Gaussian elimination with partial (row) pivoting. Used to drive
    /// `StdAugSystemSolver` end-to-end without an MA57 dependency.
    struct DenseMock {
        dim: Index,
        nz: Index,
        a: Vec<Number>,
        last_factor: Vec<Number>, // densified full symmetric matrix, row-major `dim*dim`; not factored in place
        neg_evals: Index,
    }

    impl DenseMock {
        fn new() -> Self {
            Self {
                dim: 0,
                nz: 0,
                a: Vec::new(),
                last_factor: Vec::new(),
                neg_evals: 0,
            }
        }
    }

    impl SparseSymLinearSolverInterface for DenseMock {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nz: Index,
            _ia: &[Index],
            _ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim;
            self.nz = nz;
            self.a = vec![0.0; nz as usize];
            ESymSolverStatus::Success
        }
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.a
        }
        fn multi_solve(
            &mut self,
            new_matrix: bool,
            ia: &[Index],
            ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            _check: bool,
            _nev: Index,
        ) -> ESymSolverStatus {
            let n = self.dim as usize;
            if new_matrix {
                // Densify the symmetric triplet into row-major full
                // matrix for LU. Off-diagonal entries are mirrored;
                // duplicate triplet entries are summed.
                let mut dense = vec![0.0; n * n];
                for k in 0..self.nz as usize {
                    let i = (ia[k] - 1) as usize;
                    let j = (ja[k] - 1) as usize;
                    dense[i * n + j] += self.a[k];
                    if i != j {
                        dense[j * n + i] += self.a[k];
                    }
                }
                self.last_factor = dense;
            }
            // Gaussian elimination with partial pivoting, redone from the
            // dense copy for each rhs column.
            for col in 0..nrhs as usize {
                let mut a = self.last_factor.clone();
                let b = &mut rhs_vals[col * n..col * n + n];
                // Counts negative pivots of the row-pivoted elimination as
                // a stand-in for inertia; once rows are swapped this is not
                // guaranteed to equal the number of negative eigenvalues.
                let mut neg = 0_i32;
                for k in 0..n {
                    // Find pivot row by max-abs in col k below k.
                    let mut piv = k;
                    let mut piv_abs = a[k * n + k].abs();
                    for r in (k + 1)..n {
                        let av = a[r * n + k].abs();
                        if av > piv_abs {
                            piv_abs = av;
                            piv = r;
                        }
                    }
                    if piv != k {
                        for c in 0..n {
                            a.swap(k * n + c, piv * n + c);
                        }
                        b.swap(k, piv);
                    }
                    let p = a[k * n + k];
                    if p.abs() < 1e-14 {
                        return ESymSolverStatus::Singular;
                    }
                    if p < 0.0 {
                        neg += 1;
                    }
                    for r in (k + 1)..n {
                        let f = a[r * n + k] / p;
                        for c in k..n {
                            a[r * n + c] -= f * a[k * n + c];
                        }
                        b[r] -= f * b[k];
                    }
                }
                // Back-substitute.
                for k in (0..n).rev() {
                    let mut s = b[k];
                    for c in (k + 1)..n {
                        s -= a[k * n + c] * b[c];
                    }
                    b[k] = s / a[k * n + k];
                }
                self.neg_evals = neg;
            }
            ESymSolverStatus::Success
        }
        fn number_of_neg_evals(&self) -> Index {
            self.neg_evals
        }
        fn increase_quality(&mut self) -> bool {
            false
        }
        fn provides_inertia(&self) -> bool {
            true
        }
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Hand-built tiny KKT system (n_x=2, n_s=1, n_c=1, n_d=1):
    ///
    /// ```text
    ///   W = diag(2, 3)        D_x = (0, 0)   δ_x = 0
    ///   D_s = (1)             δ_s = 0
    ///   J_c = [1  1]          D_c = (0)      δ_c = 0
    ///   J_d = [1  0]          D_d = (0)      δ_d = 0
    /// ```
    ///
    /// Pick rhs so that the solution is
    /// `(dx₁, dx₂, ds, dyc, dyd) = (1, 1, 1, 1, 1)` — five unknowns.
    /// Derive rhs from `K · sol`.
    #[test]
    fn solves_5x5_kkt_through_dense_mock() {
        // ---- W ----
        let w_space = SymTMatrixSpace::new(2, vec![1, 2], vec![1, 2]);
        let mut w = SymTMatrix::new(w_space);
        w.set_values(&[2.0, 3.0]);

        // ---- J_c (1×2 dense in triplet) ----
        let jc_space = GenTMatrixSpace::new(1, 2, vec![1, 1], vec![1, 2]);
        let mut j_c = GenTMatrix::new(jc_space);
        j_c.set_values(&[1.0, 1.0]);

        // ---- J_d (1×2) ----
        let jd_space = GenTMatrixSpace::new(1, 2, vec![1], vec![1]);
        let mut j_d = GenTMatrix::new(jd_space);
        j_d.set_values(&[1.0]);

        // ---- D_s = 1 (homogeneous) ----
        let s_space = DenseVectorSpace::new(1);
        let mut d_s = s_space.make_new_dense();
        d_s.set_values(&[1.0]);

        // RHS slots — match Ipopt convention: (rhs_x, rhs_s, rhs_c, rhs_d).
        // Compute K · (1,1,1,1,1):
        //   row x1: W₁₁·dx₁ + D_x + J_c₁·yc + J_d₁·yd = 2·1 + 0 + 1·1 + 1·1 = 4
        //   row x2: W₂₂·dx₂ + D_x + J_c₂·yc + J_d₂·yd = 3·1 + 0 + 1·1 + 0·1 = 4
        //   row s:  D_s·ds + (-1)·yd = 1·1 + (-1)·1 = 0
        //   row c:  J_c·dx - D_c·yc = 1·1 + 1·1 - 0 = 2
        //   row d:  J_d·dx - ds - D_d·yd = 1·1 - 1·1 - 0 = 0
        let xs = DenseVectorSpace::new(2);
        let mut rx = xs.make_new_dense();
        rx.set_values(&[4.0, 4.0]);
        let mut rs = s_space.make_new_dense();
        rs.set_values(&[0.0]);
        let cs = DenseVectorSpace::new(1);
        let mut rc = cs.make_new_dense();
        rc.set_values(&[2.0]);
        let ds_space = DenseVectorSpace::new(1);
        let mut rd = ds_space.make_new_dense();
        rd.set_values(&[0.0]);

        let mut sx = xs.make_new_dense();
        let mut ss = s_space.make_new_dense();
        let mut sc = cs.make_new_dense();
        let mut sd = ds_space.make_new_dense();

        let linsol = TSymLinearSolver::new(Box::new(DenseMock::new()), None, false);
        let mut solver = StdAugSystemSolver::new(linsol);

        let coeffs = AugSysCoeffs {
            w: Some(&w),
            w_factor: 1.0,
            d_x: None,
            delta_x: 0.0,
            d_s: Some(&d_s),
            delta_s: 0.0,
            j_c: &j_c,
            d_c: None,
            delta_c: 0.0,
            j_d: &j_d,
            d_d: None,
            delta_d: 0.0,
        };
        let rhs = AugSysRhs {
            rhs_x: &rx,
            rhs_s: &rs,
            rhs_c: &rc,
            rhs_d: &rd,
        };
        let mut sol = AugSysSol {
            sol_x: &mut sx,
            sol_s: &mut ss,
            sol_c: &mut sc,
            sol_d: &mut sd,
        };
        let status = solver.solve(&coeffs, &rhs, &mut sol, false, 0);
        assert_eq!(status, ESymSolverStatus::Success);

        for v in sx.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_x = {v}");
        }
        for v in ss.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_s = {v}");
        }
        for v in sc.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_c = {v}");
        }
        for v in sd.values() {
            assert!((v - 1.0).abs() < 1e-10, "sol_d = {v}");
        }
    }
}