1use crate::error::SolverError;
13
/// Cholesky factor `L` of a symmetric positive-definite matrix, `A = L·Lᵀ`.
///
/// Produced by [`cholesky`] and consumed by [`CholeskyFactorization::solve`].
#[derive(Debug, Clone, PartialEq)]
pub struct CholeskyFactorization {
    /// Dimension of the factored square matrix.
    pub n: usize,
    /// `n × n` factor stored row-major: entry `(i, j)` lives at `l[i * n + j]`.
    /// When built by [`cholesky`] it is lower-triangular (entries above the
    /// diagonal are zero).
    pub l: Vec<f32>,
}
22
23pub fn cholesky(a: &[f32], n: usize) -> Result<CholeskyFactorization, SolverError> {
29 if a.len() != n * n {
30 return Err(SolverError::NotSquare {
31 rows: n,
32 cols: a.len() / n.max(1),
33 });
34 }
35
36 let mut l = vec![0.0f32; n * n];
37
38 for j in 0..n {
39 let mut sum = f64::from(a[j * n + j]);
41 for k in 0..j {
42 let ljk = f64::from(l[j * n + k]);
43 sum -= ljk * ljk;
44 }
45
46 if sum <= 0.0 {
47 return Err(SolverError::NotPositiveDefinite(j));
48 }
49 l[j * n + j] = sum.sqrt() as f32;
50
51 let ljj_inv = 1.0 / f64::from(l[j * n + j]);
52
53 for i in (j + 1)..n {
55 let mut sum = f64::from(a[i * n + j]);
56 for k in 0..j {
57 sum -= f64::from(l[i * n + k]) * f64::from(l[j * n + k]);
58 }
59 l[i * n + j] = (sum * ljj_inv) as f32;
60 }
61 }
62
63 Ok(CholeskyFactorization { n, l })
64}
65
66impl CholeskyFactorization {
67 pub fn solve(&self, b: &[f32]) -> Result<Vec<f32>, SolverError> {
73 if b.len() != self.n {
74 return Err(SolverError::DimensionMismatch {
75 matrix_n: self.n,
76 rhs_len: b.len(),
77 });
78 }
79
80 let n = self.n;
81
82 let mut y = b.to_vec();
84 for i in 0..n {
85 let mut sum = f64::from(y[i]);
86 for j in 0..i {
87 sum -= f64::from(self.l[i * n + j]) * f64::from(y[j]);
88 }
89 y[i] = (sum / f64::from(self.l[i * n + i])) as f32;
90 }
91
92 let mut x = y;
94 for i in (0..n).rev() {
95 let mut sum = f64::from(x[i]);
96 for j in (i + 1)..n {
97 sum -= f64::from(self.l[j * n + i]) * f64::from(x[j]);
98 }
99 x[i] = (sum / f64::from(self.l[i * n + i])) as f32;
100 }
101
102 Ok(x)
103 }
104}