rlst/traits/linalg/decompositions.rs
1//! Traits for matrix decompositions and solving linear systems.
2
3use num::Zero;
4
5use itertools::Itertools;
6
7use crate::{
8 base_types::{RlstResult, UpLo},
9 dense::{
10 array::DynArray,
11 linalg::lapack::{
12 eigenvalue_decomposition::EigMode,
13 lu::LuDecomposition,
14 pseudo_inverse::PInv,
15 qr::{EnablePivoting, QrDecomposition},
16 singular_value_decomposition::SvdMode,
17 symmeig::SymmEigMode,
18 },
19 },
20 traits::{base_operations::EvaluateArray, rlst_num::RlstScalar},
21};
22
23use super::{base::Gemm, lapack::Lapack};
24
25/// Compute the matrix inverse.
26pub trait Inverse {
27 /// The item type of the inverse.
28 type Output;
29
30 /// Compute the inverse of a matrix.
31 fn inverse(&self) -> RlstResult<Self::Output>;
32}
33
34/// Compute the LU decomposition of a matrix.
35///
36/// The LU decomposition of an `m times n` matrix `A` is defined as `A = P * L * U`, where:
37/// - `A` is the original matrix,
38/// - `P` is a permutation matrix of dimension `m x m`,
39/// - `L` is a lower triangular matrix of dimension `m x k` with unit diagonal,
40/// - `U` is an upper triangular matrix of dimension `k x n`.
41///
42/// Here, `k = min(m, n)`.
43pub trait Lu {
44 /// Item type of the LU decomposition.
45 type Item: Lapack;
46 /// Compute the LU decomposition of a matrix.
47 fn lu(&self) -> RlstResult<LuDecomposition<Self::Item>>;
48}
49
50/// Compute the QR decomposition of a matrix.
51pub trait Qr {
52 /// Item type of the QR decomposition.
53 type Item: Lapack;
54
55 /// Compute the QR decomposition of a matrix.
56 fn qr(&self, pivoting: EnablePivoting) -> RlstResult<QrDecomposition<Self::Item>>;
57}
58
59/// Compute the symmetric eigenvalue decomposition of a matrix.
60pub trait SymmEig {
61 /// Item type of the symmetric eigenvalue decomposition.
62 type Item: Lapack;
63
64 /// Compute the eigenvalues of a real symmetric or complex Hermitian matrix.
65 fn eigenvaluesh(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Real, 1>> {
66 Ok(self.eigh(UpLo::Upper, SymmEigMode::EigenvaluesOnly)?.0)
67 }
68
69 /// Compute the symmetric eigenvalue decomposition of a matrix.
70 #[allow(clippy::type_complexity)]
71 fn eigh(
72 &self,
73 uplo: UpLo,
74 mode: SymmEigMode,
75 ) -> RlstResult<(
76 DynArray<<Self::Item as RlstScalar>::Real, 1>,
77 Option<DynArray<Self::Item, 2>>,
78 )>;
79}
80
81/// Compute the eigenvalue decomposition of a matrix.
82pub trait EigenvalueDecomposition {
83 /// The item type of the matrix.
84 type Item: Lapack;
85
86 /// Return the eigenvalues of the matrix.
87 fn eigenvalues(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Complex, 1>>;
88
89 /// Compute the Schur decomposition of the matrix.
90 /// Returns a tuple containing:
91 /// - A block-upper triangular matrix `T`. The diagonal blocks are 1x1 or 2x2.
92 ///
93 /// and encode the eigenvalues of the matrix.
94 /// - A unitary matrix `Z` such that `A = Z * T * Z^H`, where `Z^H` is the conjugate transpose
95 #[allow(clippy::type_complexity)]
96 fn schur(&self) -> RlstResult<(DynArray<Self::Item, 2>, DynArray<Self::Item, 2>)>;
97
98 /// Compute the eigenvalues and eigenvectors of the matrix.
99 ///
100 /// The function returns a tuple `(lam, v, w)` containing:
101 /// - A vector `lam` of eigenvalues.
102 /// - An optional matrix `v` of right eigenvectors.
103 /// - An optional matrix `w` of left eigenvectors.
104 #[allow(clippy::type_complexity)]
105 fn eig(
106 &self,
107 mode: EigMode,
108 ) -> RlstResult<(
109 DynArray<<Self::Item as RlstScalar>::Complex, 1>,
110 Option<DynArray<<Self::Item as RlstScalar>::Complex, 2>>,
111 Option<DynArray<<Self::Item as RlstScalar>::Complex, 2>>,
112 )>;
113}
114
115/// Compute the singular value decomposition of a matrix.
116pub trait SingularValueDecomposition {
117 /// The item type of the matrix.
118 type Item: Lapack + Gemm;
119
120 /// Compute the singular values of a matrix.
121 fn singular_values(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Real, 1>>;
122
123 /// Compute the singular value decomposition of a matrix.
124 ///
125 /// The function returns a tuple containing:
126 /// - A vector of singular values.
127 /// - A matrix `U` containing the left singular vectors.
128 /// - A matrix `Vh` containing the right singular vectors as rows.
129 #[allow(clippy::type_complexity)]
130 fn svd(
131 &self,
132 mode: SvdMode,
133 ) -> RlstResult<(
134 DynArray<<Self::Item as RlstScalar>::Real, 1>,
135 DynArray<Self::Item, 2>,
136 DynArray<Self::Item, 2>,
137 )>;
138
139 /// Compute the truncated singular value decomposition of a matrix.
140 ///
141 /// **Arguments:**
142 /// - `max_singular_values`: Maximum number of singular values to compute. If `None`, all
143 /// singular values are computed.
144 /// - `tol`: Relative tolerance for truncation. Singular values smaller or equal to `tol *
145 /// s[0]`, where `s[0]` is the largest singular value, will be discarded. Zero singular values
146 /// are always discarded.
147 ///
148 /// Returns a tuple containing:
149 /// - A vector of singular values
150 /// - A matrix `U` containing the left singular vectors.
151 /// - A matrix `Vh` containing the right singular vectors as rows.
152 #[allow(clippy::type_complexity)]
153 fn svd_truncated(
154 &self,
155 max_singular_values: Option<usize>,
156 tol: Option<<Self::Item as RlstScalar>::Real>,
157 ) -> RlstResult<(
158 DynArray<<Self::Item as RlstScalar>::Real, 1>,
159 DynArray<Self::Item, 2>,
160 DynArray<Self::Item, 2>,
161 )> {
162 let (s, u, vh) = self.svd(SvdMode::Compact)?;
163
164 let nvalues = std::cmp::min(
165 match max_singular_values {
166 Some(n) => n,
167 None => s.len(),
168 },
169 s.len(),
170 );
171
172 let tol = match tol {
173 Some(t) => t,
174 None => <<Self::Item as RlstScalar>::Real as Zero>::zero(),
175 };
176
177 let count = match s
178 .iter_value()
179 .take(nvalues)
180 .find_position(|&elem| elem <= tol * s[[0]])
181 {
182 Some((index, _)) => index,
183 None => nvalues,
184 };
185
186 let s = s.into_subview([0], [count]).eval();
187 let u = u.r().into_subview([0, 0], [u.shape()[0], count]).eval();
188 let vh = vh.r().into_subview([0, 0], [count, vh.shape()[1]]).eval();
189
190 Ok((s, u, vh))
191 }
192
193 /// Compute the pseudo-inverse of a matrix.
194 ///
195 /// If an integer for `max_singular_values` is provided no more than that number of
196 /// singular values is contained in the pseudo-inverse. The optional parameter `tol` specifies a relative
197 /// cut-off tolerance for the smallest singular value contained in the pseudo-inverse. The parameter `max_singular_values`
198 /// takes precedence in the sense that if it is specified the number of singular values can never exceed the given value
199 /// irrespective of `tol`.
200 fn pseudo_inverse(
201 &self,
202 max_singular_values: Option<usize>,
203 tol: Option<<Self::Item as RlstScalar>::Real>,
204 ) -> RlstResult<PInv<Self::Item>>
205 where
206 <Self::Item as RlstScalar>::Real: Into<Self::Item>,
207 {
208 let (s, u, vh) = self.svd_truncated(max_singular_values, tol)?;
209
210 Ok(PInv::new(
211 s.into_type().eval(),
212 u.conj().transpose().eval(),
213 vh.conj().transpose().eval(),
214 ))
215 }
216}
217
218/// Generic trait for solving square or rectangular linear systems.
219pub trait Solve<Rhs> {
220 /// The output type of the solver.
221 type Output;
222
223 /// Solve the linear system `Ax = b` for `x`.
224 // If `A` is not square, the system is solved in the least-squares sense.
225 fn solve(&self, rhs: &Rhs) -> RlstResult<Self::Output>;
226}
227
228/// Cholesky decomposition for positive definite matrices.
229pub trait Cholesky {
230 /// Item type of the array.
231 type Item;
232
233 /// Compute the Cholesky decomposition of a positive definite matrix.
234 ///
235 /// **Arguments:**
236 /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
237 fn cholesky(&self, uplo: UpLo) -> RlstResult<DynArray<Self::Item, 2>>;
238}
239
240/// Cholesky solver for positive definite systems.
241pub trait CholeskySolve<Rhs> {
242 /// The output type of the Cholesky solver.
243 type Output;
244
245 /// Solve a positive definite system of linear equations using Cholesky factorization.
246 ///
247 /// **Arguments:**
248 /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
249 fn cholesky_solve(&self, uplo: UpLo, rhs: &Rhs) -> RlstResult<Self::Output>;
250}
251
252/// Solve a triangular system of linear equations.
253pub trait SolveTriangular<Rhs> {
254 /// The output type of the triangular solver.
255 type Output;
256
257 /// Solve a triangular system of linear equations.
258 ///
259 /// **Arguments:**
260 /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
261 fn solve_triangular(&self, uplo: UpLo, rhs: &Rhs) -> RlstResult<Self::Output>;
262}