Skip to main content

pounce_linsol/
t_sym_solver.rs

1//! Algorithm-side wrapper that drives a sparse symmetric backend.
2//!
3//! Port of `Algorithm/LinearSolvers/IpTSymLinearSolver.{hpp,cpp}` from
4//! Ipopt 3.14.x. This is the layer between an algorithm's "give me a
5//! `SymMatrix` plus RHS, return the solution" expectation and the
6//! per-backend [`SparseSymLinearSolverInterface`] contract.
7//!
8//! Responsibilities, mirroring upstream:
9//!
10//! * Marshal a triplet `(airn, ajcn, vals)` matrix into the layout the
11//!   backend declared via [`SparseSymLinearSolverInterface::matrix_format`].
12//! * If a [`TSymScalingMethod`] is configured, compute symmetric
13//!   scaling factors `s` once per refactor and apply
14//!   `A' = diag(s) A diag(s)` / `b' = diag(s) b` / `x = diag(s) x'`.
15//! * Drive the backend's `CALL_AGAIN` retry loop (MA57 grow case).
16//! * Forward `IncreaseQuality` to the backend, with the upstream
17//!   "switch on linear-system scaling on demand" optimization.
18//!
19//! Tag-based change detection (`SymMatrix::HasChanged`) is *not* part
20//! of this wrapper; callers in the Phase-6 KKT layer pass `new_matrix`
21//! explicitly. That keeps Phase-4 self-contained while leaving the
22//! upstream semantics intact.
23
24use crate::scaling::TSymScalingMethod;
25use crate::sparse_sym_iface::{EMatrixFormat, FactorPattern, SparseSymLinearSolverInterface};
26use crate::status::ESymSolverStatus;
27use crate::sym_solver::SymLinearSolver;
28use pounce_common::types::{Index, Number};
29use pounce_linalg::triplet_convert::{TriFull, TripletToCsrConverter};
30
31/// Driver wrapping a [`SparseSymLinearSolverInterface`] (and optionally
32/// a [`TSymScalingMethod`]).
33pub struct TSymLinearSolver {
34    backend: Box<dyn SparseSymLinearSolverInterface>,
35    scaling_method: Option<Box<dyn TSymScalingMethod>>,
36    matrix_format: EMatrixFormat,
37    converter: Option<TripletToCsrConverter>,
38
39    /// `true` once [`Self::initialize_structure`] has succeeded.
40    initialized: bool,
41    /// `true` once the row/column index arrays are populated (= ditto
42    /// in this port; see upstream's `have_structure_` for warm-start
43    /// semantics).
44    have_structure: bool,
45    /// `true` if the wrapper should currently apply scaling.
46    use_scaling: bool,
47    /// Set by [`Self::increase_quality`] when scaling-on-demand fires;
48    /// triggers a one-shot scaling-factor recompute on the next solve.
49    just_switched_on_scaling: bool,
50    /// Mirrors `linear_scaling_on_demand`. `true` keeps scaling off
51    /// until `increase_quality` switches it on; `false` scales every
52    /// refactor.
53    linear_scaling_on_demand: bool,
54
55    dim: Index,
56    nonzeros_triplet: Index,
57    nonzeros_compressed: Index,
58
59    /// 1-based row indices, one per triplet entry.
60    airn: Vec<Index>,
61    /// 1-based column indices, one per triplet entry.
62    ajcn: Vec<Index>,
63    /// Per-row symmetric scaling factors (length `dim`). Empty unless
64    /// a scaling method is configured.
65    scaling_factors: Vec<Number>,
66}
67
68impl std::fmt::Debug for TSymLinearSolver {
69    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
70        f.debug_struct("TSymLinearSolver")
71            .field("matrix_format", &self.matrix_format)
72            .field("dim", &self.dim)
73            .field("nonzeros_triplet", &self.nonzeros_triplet)
74            .field("nonzeros_compressed", &self.nonzeros_compressed)
75            .field("use_scaling", &self.use_scaling)
76            .field("initialized", &self.initialized)
77            .finish_non_exhaustive()
78    }
79}
80
81impl TSymLinearSolver {
82    /// Build a driver around `backend`. Pass `Some(scaling)` to enable
83    /// symmetric scaling. `linear_scaling_on_demand=true` matches
84    /// upstream's default and keeps scaling off until
85    /// [`Self::increase_quality`] turns it on.
86    pub fn new(
87        backend: Box<dyn SparseSymLinearSolverInterface>,
88        scaling_method: Option<Box<dyn TSymScalingMethod>>,
89        linear_scaling_on_demand: bool,
90    ) -> Self {
91        let matrix_format = backend.matrix_format();
92        let converter = match matrix_format {
93            EMatrixFormat::TripletFormat => None,
94            EMatrixFormat::CsrFormat0Offset => {
95                Some(TripletToCsrConverter::new(0, TriFull::Triangular))
96            }
97            EMatrixFormat::CsrFormat1Offset => {
98                Some(TripletToCsrConverter::new(1, TriFull::Triangular))
99            }
100            EMatrixFormat::CsrFullFormat0Offset => {
101                Some(TripletToCsrConverter::new(0, TriFull::Full))
102            }
103            EMatrixFormat::CsrFullFormat1Offset => {
104                Some(TripletToCsrConverter::new(1, TriFull::Full))
105            }
106        };
107        let use_scaling = scaling_method.is_some() && !linear_scaling_on_demand;
108        Self {
109            backend,
110            scaling_method,
111            matrix_format,
112            converter,
113            initialized: false,
114            have_structure: false,
115            use_scaling,
116            just_switched_on_scaling: false,
117            linear_scaling_on_demand,
118            dim: 0,
119            nonzeros_triplet: 0,
120            nonzeros_compressed: 0,
121            airn: Vec::new(),
122            ajcn: Vec::new(),
123            scaling_factors: Vec::new(),
124        }
125    }
126
127    /// Pin the triplet sparsity pattern. Must be called once before
128    /// the first [`Self::multi_solve`]. `airn` / `ajcn` are 1-based.
129    /// Mirrors the bulk of `TSymLinearSolver::InitializeStructure`.
130    pub fn initialize_structure(
131        &mut self,
132        dim: Index,
133        airn: &[Index],
134        ajcn: &[Index],
135    ) -> ESymSolverStatus {
136        assert_eq!(airn.len(), ajcn.len());
137        let nz = airn.len() as Index;
138        self.dim = dim;
139        self.nonzeros_triplet = nz;
140        self.airn = airn.to_vec();
141        self.ajcn = ajcn.to_vec();
142
143        let (ia, ja, nonzeros) = match self.converter.as_mut() {
144            None => (&self.airn[..], &self.ajcn[..], self.nonzeros_triplet),
145            Some(conv) => {
146                let nonzeros_compressed = conv.initialize(self.dim, &self.airn, &self.ajcn);
147                self.nonzeros_compressed = nonzeros_compressed;
148                (conv.ia(), conv.ja(), nonzeros_compressed)
149            }
150        };
151        let status = self.backend.initialize_structure(dim, nonzeros, ia, ja);
152        if status != ESymSolverStatus::Success {
153            return status;
154        }
155        if self.scaling_method.is_some() {
156            self.scaling_factors = vec![0.0; dim as usize];
157        }
158        self.have_structure = true;
159        self.initialized = true;
160        status
161    }
162
163    /// Solve `A x = b` (or multiple RHS).
164    ///
165    /// `vals` is the new triplet-format value array (length
166    /// `nonzeros_triplet`). `new_matrix=true` requests a refactor
167    /// (and a fresh scaling-factor computation if scaling is on);
168    /// `new_matrix=false` reuses the existing factor and just runs
169    /// back-substitution.
170    ///
171    /// `rhs_vals` packs `nrhs` columns, each length `dim`, in
172    /// column-major layout. Solutions overwrite `rhs_vals`.
173    #[allow(clippy::too_many_arguments)]
174    pub fn multi_solve(
175        &mut self,
176        vals: &[Number],
177        new_matrix: bool,
178        nrhs: Index,
179        rhs_vals: &mut [Number],
180        check_neg_evals: bool,
181        number_of_neg_evals: Index,
182    ) -> ESymSolverStatus {
183        debug_assert!(self.initialized);
184        debug_assert_eq!(vals.len(), self.nonzeros_triplet as usize);
185        debug_assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
186
187        // One-shot KKT dump for backend-comparison testing. Triggered
188        // when POUNCE_DBG_KKT_DUMP is set to a file path; writes one
189        // binary record (dim, nnz, nrhs, ia[], ja[], vals[], rhs[]) on
190        // the Nth multi_solve call (N = POUNCE_DBG_KKT_DUMP_SKIP, default 0),
191        // then disables itself.
192        //
193        // DEPRECATED: superseded by the unified `--dump kkt:<spec>` CLI
194        // surface (see `pounce_common::diagnostics`). Kept for the
195        // existing offline FERAL/MA57/LAPACK binary-comparison tool;
196        // the env var emits a one-shot warning on first observation.
197        {
198            use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
199            static CALL_COUNT: AtomicUsize = AtomicUsize::new(0);
200            static WARNED: AtomicBool = AtomicBool::new(false);
201            let n_call = CALL_COUNT.fetch_add(1, Ordering::SeqCst);
202            let skip: usize = std::env::var("POUNCE_DBG_KKT_DUMP_SKIP")
203                .ok()
204                .and_then(|s| s.parse().ok())
205                .unwrap_or(0);
206            if n_call < skip {
207                // not yet
208            } else if let Ok(path) = std::env::var("POUNCE_DBG_KKT_DUMP") {
209                if !WARNED.swap(true, Ordering::SeqCst) {
210                    tracing::warn!(
211                        target: "pounce::linsol",
212                        "POUNCE_DBG_KKT_DUMP is deprecated; prefer `--dump kkt:<iter-spec>` (see pounce --help)"
213                    );
214                }
215                use std::io::Write;
216                if let Ok(mut f) = std::fs::File::create(&path) {
217                    let dim = self.dim as u64;
218                    let nnz = self.nonzeros_triplet as u64;
219                    let nrhs64 = nrhs as u64;
220                    let _ = f.write_all(&dim.to_le_bytes());
221                    let _ = f.write_all(&nnz.to_le_bytes());
222                    let _ = f.write_all(&nrhs64.to_le_bytes());
223                    for &i in &self.airn {
224                        let _ = f.write_all(&(i as i64).to_le_bytes());
225                    }
226                    for &j in &self.ajcn {
227                        let _ = f.write_all(&(j as i64).to_le_bytes());
228                    }
229                    for &v in vals {
230                        let _ = f.write_all(&v.to_le_bytes());
231                    }
232                    for &v in &*rhs_vals {
233                        let _ = f.write_all(&v.to_le_bytes());
234                    }
235                    let _ = f.flush();
236                }
237                // SAFETY: removing an env var is safe in single-threaded
238                // setup; this dump fires from the main IPM thread.
239                unsafe {
240                    std::env::remove_var("POUNCE_DBG_KKT_DUMP");
241                }
242            }
243        }
244
245        // Push values + (optional) scaling into the backend.
246        let mut new_matrix = new_matrix;
247        if new_matrix || self.just_switched_on_scaling {
248            self.give_matrix_to_solver(true, vals);
249            new_matrix = true;
250        }
251
252        // Apply scaling to RHS columns (multiply by `s_i` per row).
253        if self.use_scaling {
254            for irhs in 0..nrhs as usize {
255                let base = irhs * self.dim as usize;
256                for i in 0..self.dim as usize {
257                    rhs_vals[base + i] *= self.scaling_factors[i];
258                }
259            }
260        }
261
262        // Backend solve, with `CALL_AGAIN` retry (MA57 grow path).
263        // Pre-resolve the index arrays into local pointers so we can
264        // hand them to the backend without re-borrowing `self`.
265        let status = loop {
266            let (ia_ptr, ia_len, ja_ptr, ja_len) = match self.converter.as_ref() {
267                None => (
268                    self.airn.as_ptr(),
269                    self.airn.len(),
270                    self.ajcn.as_ptr(),
271                    self.ajcn.len(),
272                ),
273                Some(c) => (c.ia().as_ptr(), c.ia().len(), c.ja().as_ptr(), c.ja().len()),
274            };
275            // SAFETY: the slices live in `self.airn/ajcn` or in the
276            // converter, both owned by `self`; the pointers are valid
277            // for the duration of this `multi_solve` call.
278            let (ia, ja) = unsafe {
279                (
280                    std::slice::from_raw_parts(ia_ptr, ia_len),
281                    std::slice::from_raw_parts(ja_ptr, ja_len),
282                )
283            };
284            let s = self.backend.multi_solve(
285                new_matrix,
286                ia,
287                ja,
288                nrhs,
289                rhs_vals,
290                check_neg_evals,
291                number_of_neg_evals,
292            );
293            if s == ESymSolverStatus::CallAgain {
294                self.give_matrix_to_solver(false, vals);
295                continue;
296            }
297            break s;
298        };
299
300        if status == ESymSolverStatus::Success && self.use_scaling {
301            // Solution comes back in scaled coordinates `x' = diag(s)
302            // x`; restore by another diag(s) multiply (since the
303            // scaled system is `(D A D)(D^-1 x) = D b` and we passed
304            // `D b`, the backend returns `D^-1 x`, hence multiply by
305            // `D` once more — see cpp:286-289).
306            for irhs in 0..nrhs as usize {
307                let base = irhs * self.dim as usize;
308                for i in 0..self.dim as usize {
309                    rhs_vals[base + i] *= self.scaling_factors[i];
310                }
311            }
312        }
313
314        status
315    }
316
317    /// Push `vals` (triplet-format) into the backend in the right
318    /// layout, optionally computing scaling factors and applying the
319    /// symmetric scale. Mirrors `TSymLinearSolver::GiveMatrixToSolver`.
320    fn give_matrix_to_solver(&mut self, new_matrix: bool, vals: &[Number]) {
321        // For triplet-format backends we write directly into the
322        // backend's array; for CSR backends we marshal via a temporary
323        // and call `convert_values`.
324        if self.matrix_format == EMatrixFormat::TripletFormat && !self.use_scaling {
325            let pa = self.backend.values_array_mut();
326            pa[..self.nonzeros_triplet as usize]
327                .copy_from_slice(&vals[..self.nonzeros_triplet as usize]);
328            return;
329        }
330
331        // Stage values in a local buffer so we can scale before
332        // shipping to the backend.
333        let mut atriplet: Vec<Number> = vals[..self.nonzeros_triplet as usize].to_vec();
334
335        if self.use_scaling {
336            if new_matrix || self.just_switched_on_scaling {
337                // `use_scaling` implies the scaling method is set
338                // (checked at construction time).
339                let Some(method) = self.scaling_method.as_mut() else {
340                    unreachable!("use_scaling without a scaling method")
341                };
342                let ok = method.compute_sym_t_scaling_factors(
343                    self.dim,
344                    self.nonzeros_triplet,
345                    &self.airn,
346                    &self.ajcn,
347                    &atriplet,
348                    &mut self.scaling_factors,
349                );
350                assert!(ok, "scaling method failed");
351                self.just_switched_on_scaling = false;
352            }
353            for (i, a) in atriplet
354                .iter_mut()
355                .enumerate()
356                .take(self.nonzeros_triplet as usize)
357            {
358                let r = (self.airn[i] - 1) as usize;
359                let c = (self.ajcn[i] - 1) as usize;
360                *a *= self.scaling_factors[r] * self.scaling_factors[c];
361            }
362        }
363
364        if self.matrix_format == EMatrixFormat::TripletFormat {
365            let pa = self.backend.values_array_mut();
366            pa[..self.nonzeros_triplet as usize].copy_from_slice(&atriplet);
367        } else {
368            let Some(conv) = self.converter.as_ref() else {
369                unreachable!("non-triplet matrix_format requires a converter");
370            };
371            let pa = self.backend.values_array_mut();
372            conv.convert_values(&atriplet, &mut pa[..self.nonzeros_compressed as usize]);
373        }
374    }
375
376    /// Pass-through to the backend's diagnostic factor-pattern
377    /// accessor. Returns `None` when the backend does not expose its
378    /// factor data (e.g. MA57). Consumed by the `--dump kkt:*+L` path.
379    pub fn factor_pattern(&self, want_values: bool) -> Option<FactorPattern> {
380        self.backend.factor_pattern(want_values)
381    }
382}
383
384impl SymLinearSolver for TSymLinearSolver {
385    fn number_of_neg_evals(&self) -> Index {
386        self.backend.number_of_neg_evals()
387    }
388
389    /// Mirrors upstream's `IncreaseQuality`: switching scaling on at
390    /// the wrapper level (`linear_scaling_on_demand=true` path) takes
391    /// precedence over asking the backend for tighter pivoting.
392    fn increase_quality(&mut self) -> bool {
393        if self.scaling_method.is_some() && !self.use_scaling && self.linear_scaling_on_demand {
394            self.use_scaling = true;
395            self.just_switched_on_scaling = true;
396            return true;
397        }
398        self.backend.increase_quality()
399    }
400
401    fn provides_inertia(&self) -> bool {
402        self.backend.provides_inertia()
403    }
404}
405
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scaling::IdentityScalingMethod;

    /// Mock triplet-format backend. It exposes its values array, returns a
    /// hand-rolled solution, and can be told to answer `CallAgain` a few
    /// times. Lets us exercise the wrapper without an FFI dep.
    #[derive(Default)]
    struct MockBackend {
        dim: Index,
        a: Vec<Number>,
        canned_solution: Vec<Number>,
        neg_evals: Index,
        increase_quality_calls: u32,
        max_increase_quality_calls: u32,
        /// Number of leading `multi_solve` calls answered with `CallAgain`.
        /// Each such answer also zeroes `a`, so the wrapper must re-push the
        /// matrix values before retrying.
        call_again_remaining: u32,
    }

    impl SparseSymLinearSolverInterface for MockBackend {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nz: Index,
            _ia: &[Index],
            _ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim;
            self.a = vec![0.0; nz as usize];
            ESymSolverStatus::Success
        }
        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.a
        }
        fn multi_solve(
            &mut self,
            _new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            _check: bool,
            _nev: Index,
        ) -> ESymSolverStatus {
            if self.call_again_remaining > 0 {
                self.call_again_remaining -= 1;
                // Simulate a workspace regrow that discards the values.
                self.a.fill(0.0);
                return ESymSolverStatus::CallAgain;
            }
            assert_eq!(rhs_vals.len(), (self.dim * nrhs) as usize);
            let dim = self.dim as usize;
            for irhs in 0..nrhs as usize {
                let base = irhs * dim;
                rhs_vals[base..base + dim].copy_from_slice(&self.canned_solution);
            }
            ESymSolverStatus::Success
        }
        fn number_of_neg_evals(&self) -> Index {
            self.neg_evals
        }
        fn increase_quality(&mut self) -> bool {
            self.increase_quality_calls += 1;
            self.increase_quality_calls <= self.max_increase_quality_calls
        }
        fn provides_inertia(&self) -> bool {
            true
        }
        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Scaling method that always returns s = (2, 3).
    struct DiagTwoThree;
    impl TSymScalingMethod for DiagTwoThree {
        fn compute_sym_t_scaling_factors(
            &mut self,
            _n: Index,
            _nnz: Index,
            _airn: &[Index],
            _ajcn: &[Index],
            _a: &[Number],
            scaling_factors: &mut [Number],
        ) -> bool {
            scaling_factors[0] = 2.0;
            scaling_factors[1] = 3.0;
            true
        }
    }

    /// Scaling method that returns all-ones factors of any length.
    struct Ones;
    impl TSymScalingMethod for Ones {
        fn compute_sym_t_scaling_factors(
            &mut self,
            _n: Index,
            _nnz: Index,
            _airn: &[Index],
            _ajcn: &[Index],
            _a: &[Number],
            scaling_factors: &mut [Number],
        ) -> bool {
            scaling_factors.fill(1.0);
            true
        }
    }

    fn make_2x2_indef_pattern() -> ([Index; 3], [Index; 3]) {
        ([1, 2, 2], [1, 1, 2])
    }

    /// Copy of the values array the wrapper last pushed into the backend.
    fn backend_values(solver: &mut TSymLinearSolver) -> Vec<Number> {
        solver.backend.values_array_mut().to_vec()
    }

    #[test]
    fn unscaled_triplet_solve_passes_values_through() {
        let backend = MockBackend {
            canned_solution: vec![10.0, 20.0],
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        let (irn, jcn) = make_2x2_indef_pattern();
        assert_eq!(
            solver.initialize_structure(2, &irn, &jcn),
            ESymSolverStatus::Success
        );

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [3.0, 4.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // Mock writes its canned solution; matrix values arrive untouched.
        assert_eq!(rhs, [10.0, 20.0]);
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);
        assert!(solver.provides_inertia());
    }

    #[test]
    fn identity_scaling_does_not_change_values() {
        let backend = MockBackend {
            canned_solution: vec![1.0, 1.0],
            ..Default::default()
        };
        // linear_scaling_on_demand=false → scaling active from the
        // first solve.
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            false,
        );
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // Identity scaling: the backend received the original value array
        // unchanged, and the canned solution survives the unscale step.
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);
        assert_eq!(rhs, [1.0, 1.0]);
    }

    #[test]
    fn nontrivial_scaling_premultiplies_matrix_and_postmultiplies_solution() {
        // With D = diag(2, 3), the backend must see D A D. Solving with RHS
        // D b returns D^-1 x, and the wrapper multiplies by D once more.
        let backend = MockBackend {
            // Mock ignores its input and returns `canned_solution`;
            // wrapper then unscales: x = D · canned = (2 * c0, 3 * c1).
            canned_solution: vec![7.0, 11.0],
            ..Default::default()
        };
        let mut solver =
            TSymLinearSolver::new(Box::new(backend), Some(Box::new(DiagTwoThree)), false);
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        // (1,1): 2*2*2, (2,1): 1*3*2, (2,2): 3*3*3.
        assert_eq!(backend_values(&mut solver), vec![8.0, 6.0, 27.0]);
        assert_eq!(rhs, [2.0 * 7.0, 3.0 * 11.0]);
    }

    #[test]
    fn scaling_is_undone_on_every_rhs_column() {
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            ..Default::default()
        };
        let mut solver =
            TSymLinearSolver::new(Box::new(backend), Some(Box::new(DiagTwoThree)), false);
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [4.0, 5.0, 6.0, 7.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 2, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [14.0, 33.0, 14.0, 33.0]);
    }

    #[test]
    fn call_again_is_retried_with_freshly_pushed_values() {
        let backend = MockBackend {
            canned_solution: vec![5.0, 6.0],
            call_again_remaining: 2,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);

        let vals = [2.0, 1.0, 3.0];
        let mut rhs = [1.0, 1.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(rhs, [5.0, 6.0]);
        // The mock zeroed its values on each `CallAgain`; the wrapper must
        // have pushed them again before the final, successful attempt.
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);
    }

    #[test]
    fn on_demand_scaling_forces_refactor_after_increase_quality() {
        let backend = MockBackend {
            canned_solution: vec![7.0, 11.0],
            ..Default::default()
        };
        let mut solver =
            TSymLinearSolver::new(Box::new(backend), Some(Box::new(DiagTwoThree)), true);
        let (irn, jcn) = make_2x2_indef_pattern();
        solver.initialize_structure(2, &irn, &jcn);
        let vals = [2.0, 1.0, 3.0];

        // Scaling is deferred: first solve is unscaled.
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(backend_values(&mut solver), vec![2.0, 1.0, 3.0]);
        assert_eq!(rhs, [7.0, 11.0]);

        // Switching scaling on must refactor with scaled values even though
        // the caller reports an unchanged matrix.
        assert!(solver.increase_quality());
        let mut rhs = [4.0, 5.0];
        assert_eq!(
            solver.multi_solve(&vals, false, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
        assert_eq!(backend_values(&mut solver), vec![8.0, 6.0, 27.0]);
        assert_eq!(rhs, [14.0, 33.0]);
    }

    #[test]
    fn empty_system_with_scaling_is_a_no_op() {
        let mut solver =
            TSymLinearSolver::new(Box::new(MockBackend::default()), Some(Box::new(Ones)), false);
        assert_eq!(
            solver.initialize_structure(0, &[], &[]),
            ESymSolverStatus::Success
        );
        let vals: [Number; 0] = [];
        let mut rhs: [Number; 0] = [];
        assert_eq!(
            solver.multi_solve(&vals, true, 1, &mut rhs, false, 0),
            ESymSolverStatus::Success
        );
    }

    #[test]
    fn increase_quality_switches_on_scaling_first() {
        let backend = MockBackend {
            canned_solution: vec![0.0, 0.0],
            // Budget of exactly one backend upgrade: if the first call had
            // reached the backend, the second one would already fail.
            max_increase_quality_calls: 1,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(
            Box::new(backend),
            Some(Box::new(IdentityScalingMethod)),
            true, // on demand
        );
        // First IncreaseQuality flips on scaling, does NOT touch the backend.
        assert!(solver.increase_quality());
        // Second IncreaseQuality consumes the backend's single upgrade.
        assert!(solver.increase_quality());
        // Third: backend budget exhausted.
        assert!(!solver.increase_quality());
    }

    #[test]
    fn increase_quality_without_scaling_goes_straight_to_backend() {
        let backend = MockBackend {
            max_increase_quality_calls: 1,
            ..Default::default()
        };
        let mut solver = TSymLinearSolver::new(Box::new(backend), None, false);
        assert!(solver.increase_quality());
        // Backend caps at 1; second call returns false.
        assert!(!solver.increase_quality());
    }
}