Skip to main content

pounce_linsol/
factorization.rs

1//! Public factor-once / solve-many handle.
2//!
3//! [`Factorization`] is the user-facing value type for "I have a sparse
4//! symmetric matrix, factor it once, then solve against the factor
5//! repeatedly." It wraps an arbitrary [`SparseSymLinearSolverInterface`]
6//! backend (feral, MA57, etc.) and the [`TSymLinearSolver`] driver that
7//! handles triplet→CSR conversion and the `CallAgain` retry loop for
8//! backends that may grow their factor arrays.
9//!
10//! Two operations preserve the cached factor:
11//!
12//! * [`Factorization::solve`] — back-substitute against the current
13//!   factor. Cheap; reuses the LDLᵀ factors.
14//! * [`Factorization::refactor`] — supply new numeric values for the
15//!   same sparsity pattern; backend reuses its symbolic factor (AMD
16//!   ordering, elimination tree, pattern cache) and only redoes the
17//!   numeric work.
18//!
19//! See the `examples/shift_invert.rs` example for a worked use of
20//! factor-once / many-RHS for a shift-invert eigenvalue probe.
21//!
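//! # Example
//!
//! The snippet is not compiled because it needs a concrete backend;
//! `backend` stands for any `Box<dyn SparseSymLinearSolverInterface>`
//! (feral, MA57, …).
//!
//! ```ignore
//! use pounce_linsol::{Factorization, FactorizationError};
//!
//! fn demo(backend: Box<dyn SparseSymLinearSolverInterface>) -> Result<(), FactorizationError> {
//!     // Lower triangle of [[2, 1], [1, 3]] as 1-based COO triplets.
//!     let mut f = Factorization::new(
//!         2,
//!         vec![1, 2, 2],
//!         vec![1, 1, 2],
//!         vec![2.0, 1.0, 3.0],
//!         backend,
//!     )?;
//!
//!     let mut rhs = vec![3.0, 4.0];
//!     f.solve_one(&mut rhs)?; // rhs is now [1.0, 1.0]
//!
//!     // Same sparsity pattern, new values: only the numeric work is redone.
//!     f.refactor(&[4.0, 1.0, 5.0])?;
//!     let mut rhs = vec![5.0, 6.0];
//!     f.solve_one(&mut rhs)?; // rhs is again [1.0, 1.0]
//!     Ok(())
//! }
//! ```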
30
31use crate::error::FactorizationError;
32use crate::sparse_sym_iface::SparseSymLinearSolverInterface;
33use crate::t_sym_solver::TSymLinearSolver;
34use pounce_common::types::{Index, Number};
35
36/// Value-typed handle holding a sparse symmetric factorization.
37///
38/// Construction (via [`Factorization::new`]) performs the symbolic +
39/// numeric factor; subsequent [`Factorization::solve`] calls are pure
40/// back-substitution. [`Factorization::refactor`] replaces the numeric
41/// values without redoing the symbolic work.
42///
43/// The matrix is supplied in **triplet (COO) format with 1-based
44/// indices over the lower triangle** — the universal denominator the
45/// trait expects. Backends that prefer CSR are fed via the standard
46/// `TripletToCsrConverter` inside the wrapper.
47pub struct Factorization {
48    inner: TSymLinearSolver,
49    dim: Index,
50    nnz: Index,
51    values: Vec<Number>,
52    inertia_known: bool,
53}
54
55impl std::fmt::Debug for Factorization {
56    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
57        f.debug_struct("Factorization")
58            .field("dim", &self.dim)
59            .field("nnz", &self.nnz)
60            .field("inertia_known", &self.inertia_known)
61            .finish_non_exhaustive()
62    }
63}
64
65impl Factorization {
66    /// Factor a new matrix. Pattern (`airn`, `ajcn`) and `values` are
67    /// the lower-triangle triplet of `A`, 1-based indices, length
68    /// `nnz` each. `backend` is any implementor (feral, MA57, …).
69    ///
70    /// Performs both the symbolic and numeric factorization. Subsequent
71    /// [`Self::solve`] calls are back-substitution only;
72    /// [`Self::refactor`] redoes the numeric work but reuses the
73    /// symbolic factor.
74    ///
75    /// # Errors
76    ///
77    /// * [`FactorizationError::Singular`] — the supplied matrix is
78    ///   numerically singular.
79    /// * [`FactorizationError::FatalError`] — unrecoverable backend
80    ///   error.
81    ///
82    /// # Panics
83    ///
84    /// Panics if `airn.len() != ajcn.len()` or if `values.len() !=
85    /// airn.len()`.
86    pub fn new(
87        dim: Index,
88        airn: Vec<Index>,
89        ajcn: Vec<Index>,
90        values: Vec<Number>,
91        backend: Box<dyn SparseSymLinearSolverInterface>,
92    ) -> Result<Self, FactorizationError> {
93        assert_eq!(
94            airn.len(),
95            ajcn.len(),
96            "airn and ajcn must have same length"
97        );
98        assert_eq!(values.len(), airn.len(), "values must match nnz");
99        let nnz = airn.len() as Index;
100        let mut inner = TSymLinearSolver::new(backend, None, false);
101        FactorizationError::from_status(inner.initialize_structure(dim, &airn, &ajcn))?;
102
103        let mut me = Self {
104            inner,
105            dim,
106            nnz,
107            values,
108            inertia_known: false,
109        };
110
111        // Initial factor. We issue a no-op back-solve (nrhs=1 with a
112        // zero RHS we then discard) because the trait does not expose
113        // a factor-only entry point — every multi_solve also runs back
114        // substitution. The cost is one triangular solve per
115        // construction, which is negligible relative to the factor.
116        me.do_factor()?;
117        Ok(me)
118    }
119
120    /// Back-substitute against the cached factor. `rhs` packs `nrhs`
121    /// columns (each length `dim`) in column-major layout; solutions
122    /// overwrite `rhs` in place.
123    ///
124    /// # Errors
125    ///
126    /// [`FactorizationError::FatalError`] — backend solve failed.
127    ///
128    /// # Panics
129    ///
130    /// Panics if `rhs.len() != dim * nrhs`.
131    pub fn solve(&mut self, rhs: &mut [Number], nrhs: usize) -> Result<(), FactorizationError> {
132        assert_eq!(
133            rhs.len(),
134            self.dim as usize * nrhs,
135            "rhs length must equal dim * nrhs"
136        );
137        let status = self.inner.multi_solve(
138            &self.values,
139            false, // new_matrix = false: pure back-substitution
140            nrhs as Index,
141            rhs,
142            false,
143            0,
144        );
145        FactorizationError::from_status(status)
146    }
147
148    /// Convenience for the common `nrhs=1` case. Identical to
149    /// `solve(rhs, 1)`.
150    pub fn solve_one(&mut self, rhs: &mut [Number]) -> Result<(), FactorizationError> {
151        self.solve(rhs, 1)
152    }
153
154    /// Replace the numeric values and refactor. Pattern is unchanged;
155    /// the backend reuses its symbolic factor / AMD ordering.
156    ///
157    /// # Errors
158    ///
159    /// Same as [`Self::new`].
160    ///
161    /// # Panics
162    ///
163    /// Panics if `new_values.len() != nnz`.
164    pub fn refactor(&mut self, new_values: &[Number]) -> Result<(), FactorizationError> {
165        assert_eq!(
166            new_values.len(),
167            self.nnz as usize,
168            "new_values length must equal nnz",
169        );
170        self.values.copy_from_slice(new_values);
171        self.inertia_known = false;
172        self.do_factor()
173    }
174
175    /// Number of negative eigenvalues from the most recent factor, if
176    /// the backend reports inertia. `None` otherwise.
177    pub fn number_of_neg_evals(&self) -> Option<Index> {
178        use crate::sym_solver::SymLinearSolver;
179        if self.inertia_known && self.inner.provides_inertia() {
180            Some(self.inner.number_of_neg_evals())
181        } else {
182            None
183        }
184    }
185
186    /// Dimension `n` of the factored `n × n` matrix.
187    pub fn dim(&self) -> Index {
188        self.dim
189    }
190
191    /// Number of nonzeros in the triplet pattern.
192    pub fn nnz(&self) -> Index {
193        self.nnz
194    }
195
196    /// Internal helper: issue a factor (and discard the back-solve).
197    fn do_factor(&mut self) -> Result<(), FactorizationError> {
198        let mut dummy_rhs = vec![0.0; self.dim as usize];
199        let status = self.inner.multi_solve(
200            &self.values,
201            true, // new_matrix = true: factor now
202            1,
203            &mut dummy_rhs,
204            false,
205            0,
206        );
207        FactorizationError::from_status(status)?;
208        self.inertia_known = true;
209        Ok(())
210    }
211}
212
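// `Factorization` is deliberately NOT asserted to be `Send`. The backend
// arrives as `Box<dyn SparseSymLinearSolverInterface>`, and that trait
// does not currently require `Send`. A trait object carries only the
// auto traits named in its type, so as long as `TSymLinearSolver` stores
// the box as that trait object, the handle is not `Send` — even when the
// concrete backend (e.g. feral or MA57) is. Moving a `Factorization`
// across threads would require adding a `Send` bound to the trait or to
// the boxed backend type.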
218
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sparse_sym_iface::EMatrixFormat;
    use crate::status::ESymSolverStatus;

    /// Minimal in-test backend that solves a dense symmetric system via
    /// LU. Lets us exercise the Factorization API without depending on
    /// feral or MA57 in this crate's test suite.
    struct DenseLuBackend {
        dim: usize,
        nnz: usize,
        rows: Vec<Index>, // 1-based
        cols: Vec<Index>, // 1-based
        values: Vec<Number>,
        // Cached LU factor of the dense matrix; rebuilt on each factor().
        factor: Option<DenseLu>,
    }

    struct DenseLu {
        a: Vec<Vec<f64>>, // L\U combined (Doolittle)
        perm: Vec<usize>,
        neg_evals: Index,
    }

    impl DenseLuBackend {
        fn new() -> Self {
            Self {
                dim: 0,
                nnz: 0,
                rows: Vec::new(),
                cols: Vec::new(),
                values: Vec::new(),
                factor: None,
            }
        }

        /// Expand the lower-triangle triplets into a full symmetric
        /// dense matrix (duplicate entries are summed).
        fn assemble_dense(&self) -> Vec<Vec<f64>> {
            let n = self.dim;
            let mut a = vec![vec![0.0; n]; n];
            for k in 0..self.nnz {
                let i = (self.rows[k] - 1) as usize;
                let j = (self.cols[k] - 1) as usize;
                a[i][j] += self.values[k];
                if i != j {
                    a[j][i] += self.values[k];
                }
            }
            a
        }

        fn factor_dense(&mut self) -> ESymSolverStatus {
            let n = self.dim;
            let mut a = self.assemble_dense();
            let mut perm: Vec<usize> = (0..n).collect();
            // Partial-pivoted LU.
            for k in 0..n {
                // Pivot.
                let mut p = k;
                let mut maxv = a[perm[k]][k].abs();
                for i in (k + 1)..n {
                    let v = a[perm[i]][k].abs();
                    if v > maxv {
                        maxv = v;
                        p = i;
                    }
                }
                if maxv < 1e-300 {
                    return ESymSolverStatus::Singular;
                }
                perm.swap(k, p);
                let pk = perm[k];
                for &pi in &perm[(k + 1)..n] {
                    let factor = a[pi][k] / a[pk][k];
                    a[pi][k] = factor;
                    // Rows `pi` and `pk` of `a` are read/written together;
                    // an iterator form would need a split borrow.
                    #[allow(clippy::needless_range_loop)]
                    for j in (k + 1)..n {
                        a[pi][j] -= factor * a[pk][j];
                    }
                }
            }
            // Count negative diagonal entries of U as a stand-in for
            // inertia. For a symmetric matrix factored without row
            // exchanges this equals the true inertia (Sylvester's law of
            // inertia). This backend does pivot, so the count is only an
            // approximation, but it is good enough to exercise the
            // inertia code path in tests.
            let mut neg = 0;
            for k in 0..n {
                if a[perm[k]][k] < 0.0 {
                    neg += 1;
                }
            }
            self.factor = Some(DenseLu {
                a,
                perm,
                neg_evals: neg as Index,
            });
            ESymSolverStatus::Success
        }

        fn solve_one(&self, b: &mut [f64]) {
            let factor = self.factor.as_ref().unwrap();
            let n = self.dim;
            // Permute.
            let mut x: Vec<f64> = factor.perm.iter().map(|&p| b[p]).collect();
            // Forward substitution (unit-lower).
            for i in 0..n {
                let pi = factor.perm[i];
                for j in 0..i {
                    x[i] -= factor.a[pi][j] * x[j];
                }
            }
            // Back substitution (upper).
            for i in (0..n).rev() {
                let pi = factor.perm[i];
                for j in (i + 1)..n {
                    x[i] -= factor.a[pi][j] * x[j];
                }
                x[i] /= factor.a[pi][i];
            }
            b.copy_from_slice(&x);
        }
    }

    impl SparseSymLinearSolverInterface for DenseLuBackend {
        fn initialize_structure(
            &mut self,
            dim: Index,
            nonzeros: Index,
            ia: &[Index],
            ja: &[Index],
        ) -> ESymSolverStatus {
            self.dim = dim as usize;
            self.nnz = nonzeros as usize;
            self.rows = ia.to_vec();
            self.cols = ja.to_vec();
            self.values = vec![0.0; self.nnz];
            ESymSolverStatus::Success
        }

        fn values_array_mut(&mut self) -> &mut [Number] {
            &mut self.values
        }

        fn multi_solve(
            &mut self,
            new_matrix: bool,
            _ia: &[Index],
            _ja: &[Index],
            nrhs: Index,
            rhs_vals: &mut [Number],
            check_neg_evals: bool,
            number_of_neg_evals: Index,
        ) -> ESymSolverStatus {
            if new_matrix {
                let s = self.factor_dense();
                if s != ESymSolverStatus::Success {
                    return s;
                }
                if check_neg_evals {
                    let actual = self.factor.as_ref().unwrap().neg_evals;
                    if actual != number_of_neg_evals {
                        return ESymSolverStatus::WrongInertia;
                    }
                }
            }
            let n = self.dim;
            for k in 0..nrhs as usize {
                let base = k * n;
                self.solve_one(&mut rhs_vals[base..base + n]);
            }
            ESymSolverStatus::Success
        }

        fn number_of_neg_evals(&self) -> Index {
            self.factor.as_ref().map(|f| f.neg_evals).unwrap_or(0)
        }

        fn increase_quality(&mut self) -> bool {
            false
        }

        fn provides_inertia(&self) -> bool {
            true
        }

        fn matrix_format(&self) -> EMatrixFormat {
            EMatrixFormat::TripletFormat
        }
    }

    /// Factor a 2x2 matrix with the shared lower-triangle pattern
    /// `(1,1), (2,1), (2,2)` and the given values.
    fn factor_2x2(values: Vec<Number>) -> Result<Factorization, FactorizationError> {
        Factorization::new(
            2,
            vec![1, 2, 2],
            vec![1, 1, 2],
            values,
            Box::new(DenseLuBackend::new()),
        )
    }

    /// SPD 2x2: `[[2,1],[1,3]]`. Lower-triangle 1-based triplets.
    /// Solving against (3, 4) gives (1, 1).
    #[test]
    fn factors_spd_2x2_and_solves_one_rhs() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let mut f =
            Factorization::new(2, airn, ajcn, values, Box::new(DenseLuBackend::new())).unwrap();
        let mut rhs = vec![3.0, 4.0];
        f.solve_one(&mut rhs).unwrap();
        assert!((rhs[0] - 1.0).abs() < 1e-12);
        assert!((rhs[1] - 1.0).abs() < 1e-12);
    }

    /// Same SPD 2x2, multiple RHS packed; results match single-RHS
    /// solves done individually.
    #[test]
    fn packed_multi_rhs_matches_one_at_a_time() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let backend1 = Box::new(DenseLuBackend::new());
        let backend2 = Box::new(DenseLuBackend::new());
        let mut f1 =
            Factorization::new(2, airn.clone(), ajcn.clone(), values.clone(), backend1).unwrap();
        let mut f2 = Factorization::new(2, airn, ajcn, values, backend2).unwrap();

        // Packed 3-RHS solve.
        let mut packed = vec![
            3.0, 4.0, // col 0 → expect (1, 1)
            5.0, 5.0, // col 1
            2.0, 6.0, // col 2
        ];
        f1.solve(&mut packed, 3).unwrap();

        // One-at-a-time for the same RHS columns.
        let mut col0 = vec![3.0, 4.0];
        let mut col1 = vec![5.0, 5.0];
        let mut col2 = vec![2.0, 6.0];
        f2.solve_one(&mut col0).unwrap();
        f2.solve_one(&mut col1).unwrap();
        f2.solve_one(&mut col2).unwrap();

        for (i, &v) in col0.iter().enumerate() {
            assert!((packed[i] - v).abs() < 1e-12, "col0 mismatch at {i}");
        }
        for (i, &v) in col1.iter().enumerate() {
            assert!((packed[2 + i] - v).abs() < 1e-12, "col1 mismatch at {i}");
        }
        for (i, &v) in col2.iter().enumerate() {
            assert!((packed[4 + i] - v).abs() < 1e-12, "col2 mismatch at {i}");
        }
    }

    /// Refactor with perturbed values; residual against the perturbed
    /// system is small.
    #[test]
    fn refactor_yields_correct_solution_for_new_values() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let mut f = Factorization::new(
            2,
            airn,
            ajcn,
            vec![2.0, 1.0, 3.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();

        // Perturb to `[[4, 1], [1, 5]]`.
        f.refactor(&[4.0, 1.0, 5.0]).unwrap();
        let mut rhs = vec![5.0, 6.0]; // expect ~ (19/19, 19/19) = (1, 1)
        f.solve_one(&mut rhs).unwrap();
        // Check residual: A x - b where A = [[4,1],[1,5]], b = (5, 6).
        let r0 = 4.0 * rhs[0] + rhs[1] - 5.0;
        let r1 = rhs[0] + 5.0 * rhs[1] - 6.0;
        assert!(r0.abs() < 1e-10);
        assert!(r1.abs() < 1e-10);
    }

    /// Singular matrix → Singular error.
    #[test]
    fn singular_matrix_returns_singular_error() {
        // `[[0, 1], [1, 0]]` is symmetric indefinite but the LU with
        // partial pivoting on it succeeds (it pivots the off-diagonal
        // up). Use a genuinely singular matrix instead: `[[1,1],[1,1]]`.
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let err = Factorization::new(
            2,
            airn,
            ajcn,
            vec![1.0, 1.0, 1.0],
            Box::new(DenseLuBackend::new()),
        )
        .unwrap_err();
        assert_eq!(err, FactorizationError::Singular);
    }

    /// `solve_one` and `solve(.., 1)` produce identical results.
    #[test]
    fn solve_one_matches_solve_with_nrhs_one() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let values = vec![2.0, 1.0, 3.0];
        let mut f1 = Factorization::new(
            2,
            airn.clone(),
            ajcn.clone(),
            values.clone(),
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        let mut f2 =
            Factorization::new(2, airn, ajcn, values, Box::new(DenseLuBackend::new())).unwrap();

        let mut rhs1 = vec![3.0, 4.0];
        let mut rhs2 = vec![3.0, 4.0];
        f1.solve_one(&mut rhs1).unwrap();
        f2.solve(&mut rhs2, 1).unwrap();
        assert_eq!(rhs1, rhs2);
    }

    /// Inertia is reported after construction; backend says it
    /// provides inertia.
    #[test]
    fn inertia_is_reported_when_backend_provides_it() {
        let airn = vec![1, 2, 2];
        let ajcn = vec![1, 1, 2];
        let f = Factorization::new(
            2,
            airn,
            ajcn,
            vec![2.0, 1.0, 3.0], // SPD, so 0 negative eigenvalues
            Box::new(DenseLuBackend::new()),
        )
        .unwrap();
        assert_eq!(f.number_of_neg_evals(), Some(0));
        assert_eq!(f.dim(), 2);
        assert_eq!(f.nnz(), 3);
    }

    /// Indefinite `diag(1, -1)` reports exactly one negative eigenvalue.
    #[test]
    fn indefinite_matrix_reports_one_negative_eigenvalue() {
        // The explicit zero off-diagonal keeps the shared 2x2 pattern.
        let f = factor_2x2(vec![1.0, 0.0, -1.0]).unwrap();
        assert_eq!(f.number_of_neg_evals(), Some(1));
    }

    /// A failed refactor surfaces the backend's error and stops
    /// reporting inertia from the previous (now stale) factor.
    #[test]
    fn failed_refactor_reports_singular_and_clears_inertia() {
        let mut f = factor_2x2(vec![2.0, 1.0, 3.0]).unwrap();
        assert_eq!(f.number_of_neg_evals(), Some(0));
        let err = f.refactor(&[1.0, 1.0, 1.0]).unwrap_err();
        assert_eq!(err, FactorizationError::Singular);
        assert_eq!(f.number_of_neg_evals(), None);
    }

    /// A refactor with nonsingular values after a failed one restores a
    /// usable factor and inertia.
    #[test]
    fn refactor_recovers_after_singular_values() {
        let mut f = factor_2x2(vec![2.0, 1.0, 3.0]).unwrap();
        assert!(f.refactor(&[1.0, 1.0, 1.0]).is_err());
        f.refactor(&[2.0, 1.0, 3.0]).unwrap();
        let mut rhs = vec![3.0, 4.0];
        f.solve_one(&mut rhs).unwrap();
        assert!((rhs[0] - 1.0).abs() < 1e-12);
        assert!((rhs[1] - 1.0).abs() < 1e-12);
        assert_eq!(f.number_of_neg_evals(), Some(0));
    }

    /// `Debug` output exposes the shape fields.
    #[test]
    fn debug_output_lists_shape_fields() {
        let f = factor_2x2(vec![2.0, 1.0, 3.0]).unwrap();
        let s = format!("{f:?}");
        assert!(s.starts_with("Factorization"), "{s}");
        assert!(s.contains("dim: 2"), "{s}");
        assert!(s.contains("nnz: 3"), "{s}");
    }
}