1#[cfg(feature = "serde-serialize")]
2use serde::{Deserialize, Serialize};
3
4use num::Zero;
5use num_complex::Complex;
6
7use na::allocator::Allocator;
8use na::dimension::Dim;
9use na::storage::Storage;
10use na::{DefaultAllocator, Matrix, OMatrix, Scalar};
11
12use lapack;
13
14#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
16#[cfg_attr(
17 feature = "serde-serialize",
18 serde(bound(serialize = "DefaultAllocator: Allocator<D>,
19 OMatrix<T, D, D>: Serialize"))
20)]
21#[cfg_attr(
22 feature = "serde-serialize",
23 serde(bound(deserialize = "DefaultAllocator: Allocator<D>,
24 OMatrix<T, D, D>: Deserialize<'de>"))
25)]
26#[derive(Clone, Debug)]
27pub struct Cholesky<T: Scalar, D: Dim>
28where
29 DefaultAllocator: Allocator<D, D>,
30{
31 l: OMatrix<T, D, D>,
32}
33
34impl<T: Scalar + Copy, D: Dim> Copy for Cholesky<T, D>
35where
36 DefaultAllocator: Allocator<D, D>,
37 OMatrix<T, D, D>: Copy,
38{
39}
40
41impl<T: CholeskyScalar + Zero, D: Dim> Cholesky<T, D>
42where
43 DefaultAllocator: Allocator<D, D>,
44{
45 #[inline]
50 pub fn new(mut m: OMatrix<T, D, D>) -> Option<Self> {
51 assert!(
53 m.is_square(),
54 "Unable to compute the Cholesky decomposition of a non-square matrix."
55 );
56
57 let uplo = b'L';
58 let dim = m.nrows() as i32;
59 let mut info = 0;
60
61 T::xpotrf(uplo, dim, m.as_mut_slice(), dim, &mut info);
62 lapack_check!(info);
63
64 Some(Self { l: m })
65 }
66
67 pub fn unpack(mut self) -> OMatrix<T, D, D> {
69 self.l.fill_upper_triangle(Zero::zero(), 1);
70 self.l
71 }
72
73 pub fn unpack_dirty(self) -> OMatrix<T, D, D> {
79 self.l
80 }
81
82 #[must_use]
84 pub fn l(&self) -> OMatrix<T, D, D> {
85 let mut res = self.l.clone();
86 res.fill_upper_triangle(Zero::zero(), 1);
87 res
88 }
89
90 #[must_use]
96 pub fn l_dirty(&self) -> &OMatrix<T, D, D> {
97 &self.l
98 }
99
100 pub fn solve<R2: Dim, C2: Dim, S2>(
103 &self,
104 b: &Matrix<T, R2, C2, S2>,
105 ) -> Option<OMatrix<T, R2, C2>>
106 where
107 S2: Storage<T, R2, C2>,
108 DefaultAllocator: Allocator<R2, C2>,
109 {
110 let mut res = b.clone_owned();
111 if self.solve_mut(&mut res) {
112 Some(res)
113 } else {
114 None
115 }
116 }
117
118 pub fn solve_mut<R2: Dim, C2: Dim>(&self, b: &mut OMatrix<T, R2, C2>) -> bool
121 where
122 DefaultAllocator: Allocator<R2, C2>,
123 {
124 let dim = self.l.nrows();
125
126 assert!(
127 b.nrows() == dim,
128 "The number of rows of `b` must be equal to the dimension of the matrix `a`."
129 );
130
131 let nrhs = b.ncols() as i32;
132 let lda = dim as i32;
133 let ldb = dim as i32;
134 let mut info = 0;
135
136 T::xpotrs(
137 b'L',
138 dim as i32,
139 nrhs,
140 self.l.as_slice(),
141 lda,
142 b.as_mut_slice(),
143 ldb,
144 &mut info,
145 );
146 lapack_test!(info)
147 }
148
149 pub fn inverse(mut self) -> Option<OMatrix<T, D, D>> {
151 let dim = self.l.nrows();
152 let mut info = 0;
153
154 T::xpotri(
155 b'L',
156 dim as i32,
157 self.l.as_mut_slice(),
158 dim as i32,
159 &mut info,
160 );
161 lapack_check!(info);
162
163 for i in 0..dim {
165 for j in i + 1..dim {
166 unsafe { *self.l.get_unchecked_mut((i, j)) = *self.l.get_unchecked((j, i)) };
167 }
168 }
169
170 Some(self.l)
171 }
172}
173
174pub trait CholeskyScalar: Scalar + Copy {
182 #[allow(missing_docs)]
183 fn xpotrf(uplo: u8, n: i32, a: &mut [Self], lda: i32, info: &mut i32);
184 #[allow(missing_docs)]
185 fn xpotrs(
186 uplo: u8,
187 n: i32,
188 nrhs: i32,
189 a: &[Self],
190 lda: i32,
191 b: &mut [Self],
192 ldb: i32,
193 info: &mut i32,
194 );
195 #[allow(missing_docs)]
196 fn xpotri(uplo: u8, n: i32, a: &mut [Self], lda: i32, info: &mut i32);
197}
198
199macro_rules! cholesky_scalar_impl(
200 ($N: ty, $xpotrf: path, $xpotrs: path, $xpotri: path) => (
201 impl CholeskyScalar for $N {
202 #[inline]
203 fn xpotrf(uplo: u8, n: i32, a: &mut [Self], lda: i32, info: &mut i32) {
204 unsafe { $xpotrf(uplo, n, a, lda, info) }
205 }
206
207 #[inline]
208 fn xpotrs(uplo: u8, n: i32, nrhs: i32, a: &[Self], lda: i32,
209 b: &mut [Self], ldb: i32, info: &mut i32) {
210 unsafe { $xpotrs(uplo, n, nrhs, a, lda, b, ldb, info) }
211 }
212
213 #[inline]
214 fn xpotri(uplo: u8, n: i32, a: &mut [Self], lda: i32, info: &mut i32) {
215 unsafe { $xpotri(uplo, n, a, lda, info) }
216 }
217 }
218 )
219);
220
221cholesky_scalar_impl!(f32, lapack::spotrf, lapack::spotrs, lapack::spotri);
222cholesky_scalar_impl!(f64, lapack::dpotrf, lapack::dpotrs, lapack::dpotri);
223cholesky_scalar_impl!(Complex<f32>, lapack::cpotrf, lapack::cpotrs, lapack::cpotri);
224cholesky_scalar_impl!(Complex<f64>, lapack::zpotrf, lapack::zpotrs, lapack::zpotri);