Skip to main content

trueno_solve/
cholesky.rs

//! Cholesky factorization for symmetric positive definite matrices.
//!
//! # Contract
//!
//! A = L L^T where L is lower triangular with positive diagonal.
//!
//! ## Proof obligations
//! - L is lower triangular
//! - ||A - L L^T||_F / ||A||_F < n · u, where u is the unit roundoff of `f32` (the storage precision)
//! - L diagonal entries are positive
11
12use crate::error::SolverError;
13
/// Result of a Cholesky factorization `A = L L^T`.
///
/// Holds the `n × n` lower-triangular factor `L` in row-major order, so
/// `L[i][j]` is stored at `l[i * n + j]`. Only entries with `i >= j` are
/// meaningful; [`cholesky`] fills the strict upper triangle with zeros.
#[derive(Debug, Clone, PartialEq)]
pub struct CholeskyFactorization {
    /// Dimension `n` of the square factored matrix.
    pub n: usize,
    /// Lower-triangular factor `L`, row-major, `n * n` entries.
    pub l: Vec<f32>,
}
22
23/// Cholesky factorization: A = L L^T.
24///
25/// # Errors
26///
27/// Returns `NotPositiveDefinite` if a non-positive pivot is encountered.
28pub fn cholesky(a: &[f32], n: usize) -> Result<CholeskyFactorization, SolverError> {
29    if a.len() != n * n {
30        return Err(SolverError::NotSquare {
31            rows: n,
32            cols: a.len() / n.max(1),
33        });
34    }
35
36    let mut l = vec![0.0f32; n * n];
37
38    for j in 0..n {
39        // Diagonal entry
40        let mut sum = f64::from(a[j * n + j]);
41        for k in 0..j {
42            let ljk = f64::from(l[j * n + k]);
43            sum -= ljk * ljk;
44        }
45
46        if sum <= 0.0 {
47            return Err(SolverError::NotPositiveDefinite(j));
48        }
49        l[j * n + j] = sum.sqrt() as f32;
50
51        let ljj_inv = 1.0 / f64::from(l[j * n + j]);
52
53        // Below-diagonal entries
54        for i in (j + 1)..n {
55            let mut sum = f64::from(a[i * n + j]);
56            for k in 0..j {
57                sum -= f64::from(l[i * n + k]) * f64::from(l[j * n + k]);
58            }
59            l[i * n + j] = (sum * ljj_inv) as f32;
60        }
61    }
62
63    Ok(CholeskyFactorization { n, l })
64}
65
66impl CholeskyFactorization {
67    /// Solve Ax = b using Cholesky factors: L L^T x = b.
68    ///
69    /// # Errors
70    ///
71    /// Returns error on dimension mismatch.
72    pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
73        if b.len() != self.n {
74            return Err(SolverError::DimensionMismatch {
75                matrix_n: self.n,
76                rhs_len: b.len(),
77            });
78        }
79
80        let n = self.n;
81
82        // Forward substitution: L y = b
83        let mut y = b.to_vec();
84        for i in 0..n {
85            let mut sum = f64::from(y[i]);
86            for j in 0..i {
87                sum -= f64::from(self.l[i * n + j]) * f64::from(y[j]);
88            }
89            y[i] = (sum / f64::from(self.l[i * n + i])) as f32;
90        }
91
92        // Back substitution: L^T x = y
93        let mut x = y;
94        for i in (0..n).rev() {
95            let mut sum = f64::from(x[i]);
96            for j in (i + 1)..n {
97                sum -= f64::from(self.l[j * n + i]) * f64::from(x[j]);
98            }
99            x[i] = (sum / f64::from(self.l[i * n + i])) as f32;
100        }
101
102        Ok(x)
103    }
104}