rlst/traits/linalg/decompositions.rs

1//! Traits for matrix decompositions and solving linear systems.
2
3use num::Zero;
4
5use itertools::Itertools;
6
7use crate::{
8    base_types::{RlstResult, UpLo},
9    dense::{
10        array::DynArray,
11        linalg::lapack::{
12            eigenvalue_decomposition::EigMode,
13            lu::LuDecomposition,
14            pseudo_inverse::PInv,
15            qr::{EnablePivoting, QrDecomposition},
16            singular_value_decomposition::SvdMode,
17            symmeig::SymmEigMode,
18        },
19    },
20    traits::{base_operations::EvaluateArray, rlst_num::RlstScalar},
21};
22
23use super::{base::Gemm, lapack::Lapack};
24
25/// Compute the matrix inverse.
26pub trait Inverse {
27    /// The item type of the inverse.
28    type Output;
29
30    /// Compute the inverse of a matrix.
31    fn inverse(&self) -> RlstResult<Self::Output>;
32}
33
34/// Compute the LU decomposition of a matrix.
35///
36/// The LU decomposition of an `m times n` matrix `A` is defined as `A = P * L * U`, where:
37/// - `A` is the original matrix,
38/// - `P` is a permutation matrix of dimension `m x m`,
39/// - `L` is a lower triangular matrix of dimension `m x k` with unit diagonal,
40/// - `U` is an upper triangular matrix of dimension `k x n`.
41///
42/// Here, `k = min(m, n)`.
43pub trait Lu {
44    /// Item type of the LU decomposition.
45    type Item: Lapack;
46    /// Compute the LU decomposition of a matrix.
47    fn lu(&self) -> RlstResult<LuDecomposition<Self::Item>>;
48}
49
50/// Compute the QR decomposition of a matrix.
51pub trait Qr {
52    /// Item type of the QR decomposition.
53    type Item: Lapack;
54
55    /// Compute the QR decomposition of a matrix.
56    fn qr(&self, pivoting: EnablePivoting) -> RlstResult<QrDecomposition<Self::Item>>;
57}
58
59/// Compute the symmetric eigenvalue decomposition of a matrix.
60pub trait SymmEig {
61    /// Item type of the symmetric eigenvalue decomposition.
62    type Item: Lapack;
63
64    /// Compute the eigenvalues of a real symmetric or complex Hermitian matrix.
65    fn eigenvaluesh(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Real, 1>> {
66        Ok(self.eigh(UpLo::Upper, SymmEigMode::EigenvaluesOnly)?.0)
67    }
68
69    /// Compute the symmetric eigenvalue decomposition of a matrix.
70    #[allow(clippy::type_complexity)]
71    fn eigh(
72        &self,
73        uplo: UpLo,
74        mode: SymmEigMode,
75    ) -> RlstResult<(
76        DynArray<<Self::Item as RlstScalar>::Real, 1>,
77        Option<DynArray<Self::Item, 2>>,
78    )>;
79}
80
81/// Compute the eigenvalue decomposition of a matrix.
82pub trait EigenvalueDecomposition {
83    /// The item type of the matrix.
84    type Item: Lapack;
85
86    /// Return the eigenvalues of the matrix.
87    fn eigenvalues(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Complex, 1>>;
88
89    /// Compute the Schur decomposition of the matrix.
90    /// Returns a tuple containing:
91    /// - A block-upper triangular matrix `T`. The diagonal blocks are 1x1 or 2x2.
92    ///
93    /// and encode the eigenvalues of the matrix.
94    /// - A unitary matrix `Z` such that `A = Z * T * Z^H`, where `Z^H` is the conjugate transpose
95    #[allow(clippy::type_complexity)]
96    fn schur(&self) -> RlstResult<(DynArray<Self::Item, 2>, DynArray<Self::Item, 2>)>;
97
98    /// Compute the eigenvalues and eigenvectors of the matrix.
99    ///
100    /// The function returns a tuple `(lam, v, w)` containing:
101    /// - A vector `lam` of eigenvalues.
102    /// - An optional matrix `v` of right eigenvectors.
103    /// - An optional matrix `w` of left eigenvectors.
104    #[allow(clippy::type_complexity)]
105    fn eig(
106        &self,
107        mode: EigMode,
108    ) -> RlstResult<(
109        DynArray<<Self::Item as RlstScalar>::Complex, 1>,
110        Option<DynArray<<Self::Item as RlstScalar>::Complex, 2>>,
111        Option<DynArray<<Self::Item as RlstScalar>::Complex, 2>>,
112    )>;
113}
114
115/// Compute the singular value decomposition of a matrix.
116pub trait SingularValueDecomposition {
117    /// The item type of the matrix.
118    type Item: Lapack + Gemm;
119
120    /// Compute the singular values of a matrix.
121    fn singular_values(&self) -> RlstResult<DynArray<<Self::Item as RlstScalar>::Real, 1>>;
122
123    /// Compute the singular value decomposition of a matrix.
124    ///
125    /// The function returns a tuple containing:
126    /// - A vector of singular values.
127    /// - A matrix `U` containing the left singular vectors.
128    /// - A matrix `Vh` containing the right singular vectors as rows.
129    #[allow(clippy::type_complexity)]
130    fn svd(
131        &self,
132        mode: SvdMode,
133    ) -> RlstResult<(
134        DynArray<<Self::Item as RlstScalar>::Real, 1>,
135        DynArray<Self::Item, 2>,
136        DynArray<Self::Item, 2>,
137    )>;
138
139    /// Compute the truncated singular value decomposition of a matrix.
140    ///
141    /// **Arguments:**
142    /// - `max_singular_values`: Maximum number of singular values to compute. If `None`, all
143    ///   singular values are computed.
144    /// - `tol`: Relative tolerance for truncation. Singular values smaller or equal to `tol *
145    ///   s[0]`, where `s[0]` is the largest singular value, will be discarded. Zero singular values
146    ///   are always discarded.
147    ///
148    /// Returns a tuple containing:
149    /// - A vector of singular values
150    /// - A matrix `U` containing the left singular vectors.
151    /// - A matrix `Vh` containing the right singular vectors as rows.
152    #[allow(clippy::type_complexity)]
153    fn svd_truncated(
154        &self,
155        max_singular_values: Option<usize>,
156        tol: Option<<Self::Item as RlstScalar>::Real>,
157    ) -> RlstResult<(
158        DynArray<<Self::Item as RlstScalar>::Real, 1>,
159        DynArray<Self::Item, 2>,
160        DynArray<Self::Item, 2>,
161    )> {
162        let (s, u, vh) = self.svd(SvdMode::Compact)?;
163
164        let nvalues = std::cmp::min(
165            match max_singular_values {
166                Some(n) => n,
167                None => s.len(),
168            },
169            s.len(),
170        );
171
172        let tol = match tol {
173            Some(t) => t,
174            None => <<Self::Item as RlstScalar>::Real as Zero>::zero(),
175        };
176
177        let count = match s
178            .iter_value()
179            .take(nvalues)
180            .find_position(|&elem| elem <= tol * s[[0]])
181        {
182            Some((index, _)) => index,
183            None => nvalues,
184        };
185
186        let s = s.into_subview([0], [count]).eval();
187        let u = u.r().into_subview([0, 0], [u.shape()[0], count]).eval();
188        let vh = vh.r().into_subview([0, 0], [count, vh.shape()[1]]).eval();
189
190        Ok((s, u, vh))
191    }
192
193    /// Compute the pseudo-inverse of a matrix.
194    ///
195    /// If an integer for `max_singular_values` is provided no more than that number of
196    /// singular values is contained in the pseudo-inverse. The optional parameter `tol` specifies a relative
197    /// cut-off tolerance for the smallest singular value contained in the pseudo-inverse. The parameter `max_singular_values`
198    /// takes precedence in the sense that if it is specified the number of singular values can never exceed the given value
199    /// irrespective of `tol`.
200    fn pseudo_inverse(
201        &self,
202        max_singular_values: Option<usize>,
203        tol: Option<<Self::Item as RlstScalar>::Real>,
204    ) -> RlstResult<PInv<Self::Item>>
205    where
206        <Self::Item as RlstScalar>::Real: Into<Self::Item>,
207    {
208        let (s, u, vh) = self.svd_truncated(max_singular_values, tol)?;
209
210        Ok(PInv::new(
211            s.into_type().eval(),
212            u.conj().transpose().eval(),
213            vh.conj().transpose().eval(),
214        ))
215    }
216}
217
218/// Generic trait for solving square or rectangular linear systems.
219pub trait Solve<Rhs> {
220    /// The output type of the solver.
221    type Output;
222
223    /// Solve the linear system `Ax = b` for `x`.
224    // If `A` is not square, the system is solved in the least-squares sense.
225    fn solve(&self, rhs: &Rhs) -> RlstResult<Self::Output>;
226}
227
228/// Cholesky decomposition for positive definite matrices.
229pub trait Cholesky {
230    /// Item type of the array.
231    type Item;
232
233    /// Compute the Cholesky decomposition of a positive definite matrix.
234    ///
235    /// **Arguments:**
236    /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
237    fn cholesky(&self, uplo: UpLo) -> RlstResult<DynArray<Self::Item, 2>>;
238}
239
240/// Cholesky solver for positive definite systems.
241pub trait CholeskySolve<Rhs> {
242    /// The output type of the Cholesky solver.
243    type Output;
244
245    /// Solve a positive definite system of linear equations using Cholesky factorization.
246    ///
247    /// **Arguments:**
248    /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
249    fn cholesky_solve(&self, uplo: UpLo, rhs: &Rhs) -> RlstResult<Self::Output>;
250}
251
252/// Solve a triangular system of linear equations.
253pub trait SolveTriangular<Rhs> {
254    /// The output type of the triangular solver.
255    type Output;
256
257    /// Solve a triangular system of linear equations.
258    ///
259    /// **Arguments:**
260    /// - `uplo`: Specifies whether the upper or lower triangular part of the matrix is stored.
261    fn solve_triangular(&self, uplo: UpLo, rhs: &Rhs) -> RlstResult<Self::Output>;
262}