nalgebra_lapack/
svd.rs

1#[cfg(feature = "serde-serialize")]
2use serde::{Deserialize, Serialize};
3
4use num::Signed;
5use std::cmp;
6
7use na::allocator::Allocator;
8use na::dimension::{Const, Dim, DimMin, DimMinimum, U1};
9use na::{DefaultAllocator, Matrix, OMatrix, OVector, Scalar};
10
11use lapack;
12
13/// The SVD decomposition of a general matrix.
14#[cfg_attr(feature = "serde-serialize", derive(Serialize, Deserialize))]
15#[cfg_attr(
16    feature = "serde-serialize",
17    serde(bound(serialize = "DefaultAllocator: Allocator<DimMinimum<R, C>> +
18                           Allocator<R, R> +
19                           Allocator<C, C>,
20         OMatrix<T, R, R>: Serialize,
21         OMatrix<T, C, C>: Serialize,
22         OVector<T, DimMinimum<R, C>>: Serialize"))
23)]
24#[cfg_attr(
25    feature = "serde-serialize",
26    serde(bound(deserialize = "DefaultAllocator: Allocator<DimMinimum<R, C>> +
27                             Allocator<R, R> +
28                             Allocator<C, C>,
29         OMatrix<T, R, R>: Deserialize<'de>,
30         OMatrix<T, C, C>: Deserialize<'de>,
31         OVector<T, DimMinimum<R, C>>: Deserialize<'de>"))
32)]
33#[derive(Clone, Debug)]
34pub struct SVD<T: Scalar, R: DimMin<C>, C: Dim>
35where
36    DefaultAllocator: Allocator<R, R> + Allocator<DimMinimum<R, C>> + Allocator<C, C>,
37{
38    /// The left-singular vectors `U` of this SVD.
39    pub u: OMatrix<T, R, R>, // TODO: should be OMatrix<T, R, DimMinimum<R, C>>
40    /// The right-singular vectors `V^t` of this SVD.
41    pub vt: OMatrix<T, C, C>, // TODO: should be OMatrix<T, DimMinimum<R, C>, C>
42    /// The singular values of this SVD.
43    pub singular_values: OVector<T, DimMinimum<R, C>>,
44}
45
46impl<T: Scalar + Copy, R: DimMin<C>, C: Dim> Copy for SVD<T, R, C>
47where
48    DefaultAllocator: Allocator<C, C> + Allocator<R, R> + Allocator<DimMinimum<R, C>>,
49    OMatrix<T, R, R>: Copy,
50    OMatrix<T, C, C>: Copy,
51    OVector<T, DimMinimum<R, C>>: Copy,
52{
53}
54
55/// Trait implemented by floats (`f32`, `f64`) and complex floats (`Complex<f32>`, `Complex<f64>`)
56/// supported by the Singular Value Decompotition.
57pub trait SVDScalar<R: DimMin<C>, C: Dim>: Scalar
58where
59    DefaultAllocator:
60        Allocator<R, R> + Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C, C>,
61{
62    /// Computes the SVD decomposition of `m`.
63    fn compute(m: OMatrix<Self, R, C>) -> Option<SVD<Self, R, C>>;
64}
65
66impl<T: SVDScalar<R, C>, R: DimMin<C>, C: Dim> SVD<T, R, C>
67where
68    DefaultAllocator:
69        Allocator<R, R> + Allocator<R, C> + Allocator<DimMinimum<R, C>> + Allocator<C, C>,
70{
71    /// Computes the Singular Value Decomposition of `matrix`.
72    pub fn new(m: OMatrix<T, R, C>) -> Option<Self> {
73        T::compute(m)
74    }
75}
76
macro_rules! svd_impl(
    ($t: ty, $lapack_func: path) => (
        impl<R: Dim, C: Dim> SVDScalar<R, C> for $t
                where R: DimMin<C>,
                      DefaultAllocator: Allocator<R, C> +
                                        Allocator<R, R> +
                                        Allocator<C, C> +
                                        Allocator<DimMinimum<R, C>> {

            /// Computes the full SVD of `m` with LAPACK's `?gesdd` driver (`JOBZ = 'A'`).
            ///
            /// `m` is taken by value because its storage is handed to LAPACK as a mutable
            /// buffer. Returns `None` in any of these cases:
            /// - `m` has no rows or no columns;
            /// - a dimension does not fit in LAPACK's 32-bit integers;
            /// - `lapack_check!` rejects the `info` code reported by LAPACK.
            fn compute(mut m: OMatrix<$t, R, C>) -> Option<SVD<$t, R, C>> {
                let (nrows, ncols) = m.shape_generic();

                if nrows.value() == 0 || ncols.value() == 0 {
                    return None;
                }

                // LAPACK takes 32-bit integer dimensions: reject sizes that the `as i32`
                // casts below would otherwise silently truncate.
                let max_dim = i32::MAX as usize;
                if nrows.value() > max_dim || ncols.value() > max_dim {
                    return None;
                }

                let jobz = b'A';
                let lapack_nrows = nrows.value() as i32;
                let lapack_ncols = ncols.value() as i32;

                // Column-major storage: each leading dimension is the row count of its matrix.
                let lda  = lapack_nrows;
                let ldu  = lapack_nrows;
                let ldvt = lapack_ncols;

                let mut u  = Matrix::zeros_generic(nrows, nrows);
                let mut s  = Matrix::zeros_generic(nrows.min(ncols), Const::<1>);
                let mut vt = Matrix::zeros_generic(ncols, ncols);

                let mut info  = 0;
                // `?gesdd` requires an integer workspace of 8 * min(nrows, ncols) entries.
                let mut iwork = vec![0; 8 * cmp::min(nrows.value(), ncols.value())];

                // Workspace-size query: with `lwork = -1`, LAPACK only writes the optimal
                // workspace length into `work[0]`.
                let mut work: [$t; 1] = [0.0];

                // SAFETY: every buffer matches the dimensions passed alongside it:
                // - `m` is nrows x ncols with lda = nrows;
                // - `u` is nrows x nrows with ldu = nrows;
                // - `vt` is ncols x ncols with ldvt = ncols;
                // - `s` holds min(nrows, ncols) values;
                // - `iwork` holds 8 * min(nrows, ncols) integers.
                unsafe {
                    $lapack_func(jobz, lapack_nrows, lapack_ncols, m.as_mut_slice(), lda,
                        s.as_mut_slice(), u.as_mut_slice(), ldu, vt.as_mut_slice(), ldvt,
                        &mut work, -1, &mut iwork, &mut info);
                }
                lapack_check!(info);

                // The optimal size comes back as a floating-point number. In single precision,
                // values above 2^24 are not exactly representable and may have been rounded
                // *down*, which would make the real call below fail with "LWORK too small".
                // Round up by one relative epsilon first (the same correction newer LAPACK
                // releases apply internally via `SROUNDUP_LWORK`). LAPACK also requires
                // LWORK >= 1.
                let lwork = ((work[0] as f64) * (1.0 + <$t>::EPSILON as f64)).ceil() as i32;
                let lwork = lwork.max(1);
                let mut work: Vec<$t> = vec![0.0; lwork as usize];

                // SAFETY: same buffers and dimensions as the query above; `work` now holds
                // exactly `lwork` entries.
                unsafe {
                    $lapack_func(jobz, lapack_nrows, lapack_ncols, m.as_mut_slice(), lda,
                        s.as_mut_slice(), u.as_mut_slice(), ldu, vt.as_mut_slice(), ldvt,
                        &mut work, lwork, &mut iwork, &mut info);
                }
                lapack_check!(info);

                Some(SVD { u, singular_values: s, vt })
            }
        }

        impl<R: DimMin<C>, C: Dim> SVD<$t, R, C>
            // TODO: All those bounds…
            where DefaultAllocator: Allocator<R, C>                 +
                                    Allocator<C, R>                 +
                                    Allocator<U1, R>                +
                                    Allocator<U1, C>                +
                                    Allocator<R, R>                 +
                                    Allocator<DimMinimum<R, C>>     +
                                    Allocator<DimMinimum<R, C>, R>  +
                                    Allocator<DimMinimum<R, C>, C>  +
                                    Allocator<R, DimMinimum<R, C>>  +
                                    Allocator<C, C> {
            /// Reconstructs the matrix `U * Σ * Vᵗ` from its decomposition.
            ///
            /// Useful if some components (e.g. some singular values) of this decomposition have
            /// been manually changed by the user.
            #[inline]
            pub fn recompose(self) -> OMatrix<$t, R, C> {
                let nrows           = self.u.shape_generic().0;
                let ncols           = self.vt.shape_generic().1;
                let min_nrows_ncols = nrows.min(ncols);

                // Build `Σ * Vᵗ`: only the first min(R, C) rows of `Vᵗ` are scaled by a
                // singular value. The remaining rows of `res` stay zero.
                let mut res: OMatrix<_, R, C> = Matrix::zeros_generic(nrows, ncols);

                {
                    let mut sres = res.generic_view_mut((0, 0), (min_nrows_ncols, ncols));
                    sres.copy_from(&self.vt.rows_generic(0, min_nrows_ncols));

                    for i in 0 .. min_nrows_ncols.value() {
                        let sigma   = self.singular_values[i];
                        let mut row = sres.row_mut(i);
                        row *= sigma;
                    }
                }

                self.u * res
            }

            /// Computes the pseudo-inverse `V * Σ⁺ * Uᵗ` of the decomposed matrix.
            ///
            /// Singular values whose absolute value is not greater than `epsilon` are treated
            /// as zero: their reciprocal is replaced by zero instead of being computed.
            #[inline]
            #[must_use]
            pub fn pseudo_inverse(&self, epsilon: $t) -> OMatrix<$t, C, R> {
                let nrows           = self.u.shape_generic().0;
                let ncols           = self.vt.shape_generic().1;
                let min_nrows_ncols = nrows.min(ncols);

                // Build `Σ⁺ * Uᵗ`: the first min(R, C) rows are the leading columns of `U`,
                // transposed and divided by their singular value. All other rows stay zero.
                let mut res: OMatrix<_, C, R> = Matrix::zeros_generic(ncols, nrows);

                {
                    let mut sres = res.generic_view_mut((0, 0), (min_nrows_ncols, nrows));
                    self.u.columns_generic(0, min_nrows_ncols).transpose_to(&mut sres);

                    for i in 0 .. min_nrows_ncols.value() {
                        let sigma   = self.singular_values[i];
                        let mut row = sres.row_mut(i);

                        if sigma.abs() > epsilon {
                            row /= sigma;
                        } else {
                            row.fill(0.0);
                        }
                    }
                }

                // `Vᵗ` transposed times `res`, i.e. `V * Σ⁺ * Uᵗ`.
                self.vt.tr_mul(&res)
            }

            /// The rank of the decomposed matrix.
            ///
            /// This is the number of singular values whose absolute value is greater than
            /// the given `epsilon`.
            #[inline]
            #[must_use]
            pub fn rank(&self, epsilon: $t) -> usize {
                self.singular_values
                    .iter()
                    .filter(|&&sigma| sigma.abs() > epsilon)
                    .count()
            }

            // TODO: add methods to retrieve the null-space and column-space? (Respectively
            // corresponding to the zero and non-zero singular values).
        }
    );
);
224
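/*
NOTE: disabled draft of a complex-valued SVD based on LAPACK's `cgesvd` / `zgesvd`. It does not
match the current `SVDScalar<R, C>` trait and does not compile as written (for instance,
`fn compute` below declares two return types), so complex scalars are currently unsupported.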
226macro_rules! svd_complex_impl(
227    ($name: ident, $t: ty, $lapack_func: path) => (
228        impl SVDScalar for Complex<$t> {
229            fn compute<R: Dim, C: Dim, S>(mut m: Matrix<$t, R, C, S>) -> Option<SVD<$t, R, C, S::Alloc>>
230            Option<(OMatrix<Complex<$t>, R, S::Alloc>,
231                    OVector<$t, DimMinimum<R, C>, S::Alloc>,
232                    OMatrix<Complex<$t>, C, S::Alloc>)>
233            where R: DimMin<C>,
234                  S: ContiguousStorage<Complex<$t>, R, C>,
235                  S::Alloc: OwnedAllocator<Complex<$t>, R, C, S> +
236                            Allocator<R, R>         +
237                            Allocator<C, C>         +
238                            Allocator<DimMinimum<R, C>> {
239            let (nrows, ncols) = m.shape_generic();
240
241            if nrows.value() == 0 || ncols.value() == 0 {
242                return None;
243            }
244
245            let jobu  = b'A';
246            let jobvt = b'A';
247
248            let lda = nrows.value() as i32;
249            let min_nrows_ncols = nrows.min(ncols);
250
251
252            let mut u  = Matrix::zeros_generic(nrows, nrows);
253            let mut s  = Matrix::zeros_generic(min_nrows_ncols, U1);
254            let mut vt = Matrix::zeros_generic(ncols, ncols);
255
256            let ldu  = nrows.value();
257            let ldvt = ncols.value();
258
259            let mut work  = [ Complex::new(0.0, 0.0) ];
260            let mut lwork = -1 as i32;
261            let mut rwork = vec![ 0.0; (5 * min_nrows_ncols.value()) ];
262            let mut info  = 0;
263
264            $lapack_func(jobu, jobvt, nrows.value() as i32, ncols.value() as i32, m.as_mut_slice(),
265                lda, s.as_mut_slice(), u.as_mut_slice(), ldu as i32, vt.as_mut_slice(),
266                ldvt as i32, &mut work, lwork, &mut rwork, &mut info);
267            lapack_check!(info);
268
269            lwork = work[0].re as i32;
270            let mut work = vec![Complex::new(0.0, 0.0); lwork as usize];
271
272            $lapack_func(jobu, jobvt, nrows.value() as i32, ncols.value() as i32, m.as_mut_slice(),
273                lda, s.as_mut_slice(), u.as_mut_slice(), ldu as i32, vt.as_mut_slice(),
274                ldvt as i32, &mut work, lwork, &mut rwork, &mut info);
275            lapack_check!(info);
276
277            Some((u, s, vt))
278        }
279    );
280);
281*/
282
283svd_impl!(f32, lapack::sgesdd);
284svd_impl!(f64, lapack::dgesdd);
285// svd_complex_impl!(lapack_svd_complex_f32, f32, lapack::cgesvd);
286// svd_complex_impl!(lapack_svd_complex_f64, f64, lapack::zgesvd);