1use ndarray::{Array, Array1, Array2};
4
5use lax::{layout::MatrixLayout, JobSvd, Lapack};
6use crate::tools::svdapprox::*;
9
/// Size threshold related to choosing a full (dense) matrix representation.
/// NOTE(review): not referenced in this part of the file — confirm its exact meaning at its use sites.
pub(crate) const FULL_MAT_REPR: usize = 5_000;

/// Largest number of rows for which `GraphLaplacian::do_svd` runs an exact (full) svd on a
/// dense laplacian; larger laplacians, or CSR ones, go through the approximated svd.
pub(crate) const FULL_SVD_SIZE_LIMIT: usize = 5_000;
13
14#[derive(Clone)]
18pub(crate) struct GraphLaplacian {
19 sym_laplacian: MatRepr<f32>,
21 pub(crate) degrees: Array1<f32>,
23 pub(crate) svd_res: Option<SvdResult<f32>>,
25}
26
27impl GraphLaplacian {
28 pub fn new(sym_laplacian: MatRepr<f32>, degrees: Array1<f32>) -> Self {
29 GraphLaplacian {
30 sym_laplacian,
31 degrees,
32 svd_res: None,
33 }
34 } #[inline]
37 fn is_csr(&self) -> bool {
38 self.sym_laplacian.is_csr()
39 } fn get_nbrow(&self) -> usize {
42 self.degrees.len()
43 }
44
45 fn do_full_svd(&mut self) -> Result<SvdResult<f32>, String> {
46 log::info!("GraphLaplacian doing full svd");
48 log::debug!("memory : {:?}", memory_stats::memory_stats().unwrap());
49 let b = self.sym_laplacian.get_full_mut().unwrap();
50 log::trace!(
51 "GraphLaplacian ... size nbrow {} nbcol {} ",
52 b.shape()[0],
53 b.shape()[1]
54 );
55 svd_f32(b)
57 } fn do_approx_svd(&mut self, asked_dim: usize) -> Result<SvdResult<f32>, String> {
61 assert!(asked_dim >= 2);
62 log::info!(
67 "got laplacian, going to approximated svd ... asked_dim : {}",
68 asked_dim
69 );
70 let mut svdapprox = SvdApprox::new(&self.sym_laplacian);
71 let rank = 20;
75 let nbiter = 5;
76 log::trace!("asking svd, RangeRank rank : {}, nbiter : {}", rank, nbiter);
77 let svdmode = RangeApproxMode::RANK(RangeRank::new(rank, nbiter));
79 let svd_res = svdapprox.direct_svd(svdmode);
80 log::trace!("exited svd");
81 if svd_res.is_err() {
82 log::error!("svd approximation failed");
83 std::panic!();
84 }
85 self.check_norms(svd_res.as_ref().unwrap());
86 svd_res
87 } pub fn do_svd(&mut self, asked_dim: usize) -> Result<SvdResult<f32>, String> {
90 if !self.is_csr() && self.get_nbrow() <= FULL_SVD_SIZE_LIMIT {
91 self.do_full_svd()
93 } else {
94 self.do_approx_svd(asked_dim)
95 }
96 } #[allow(unused)]
99 pub(crate) fn check_norms(&self, svd_res: &SvdResult<f32>) {
100 log::trace!("in of check_norms");
101 let u = svd_res.get_u_ref().unwrap();
103 log::debug!("checking U norms , dim : {:?}", u.dim());
104 let (nb_rows, nb_cols) = u.dim();
105 for i in 0..nb_cols.min(3) {
106 let norm = norm_frobenius_full(&u.column(i));
107 log::debug!(" vector {} norm {:.2e} ", i, norm);
108 }
109 log::trace!("end of check_norms");
110 }
111} pub(crate) fn svd_f32(b: &mut Array2<f32>) -> Result<SvdResult<f32>, String> {
119 let layout = MatrixLayout::C {
120 row: b.shape()[0] as i32,
121 lda: b.shape()[1] as i32,
122 };
123 let slice_for_svd_opt = b.as_slice_mut();
124 if slice_for_svd_opt.is_none() {
125 log::error!("direct_svd Matrix cannot be transformed into a slice : not contiguous or not in standard order");
126 return Err(String::from("not contiguous or not in standard order"));
127 }
128 log::trace!("direct_svd calling svddc driver");
130 let res_svd_b = f32::svddc(layout, JobSvd::Some, slice_for_svd_opt.unwrap());
131 if res_svd_b.is_err() {
132 log::error!("direct_svd, svddc failed");
133 };
134 let res_svd_b = res_svd_b.unwrap();
140 let r = res_svd_b.s.len();
141 let m = b.shape()[0];
142 let s: Array1<f32> = res_svd_b.s.iter().copied().collect::<Array1<f32>>();
144 let s_u: Option<Array2<f32>>;
152 if let Some(u_vec) = res_svd_b.u {
153 let u_1 = Array::from_shape_vec((m, r), u_vec).unwrap();
154 s_u = Some(u_1);
155 } else {
156 s_u = None;
157 }
158 Ok(SvdResult {
160 s: Some(s),
161 u: s_u,
162 vt: None,
163 })
164}
165
#[cfg(test)]
mod tests {

    use super::*;

    fn log_init_test() {
        let _ = env_logger::builder().is_test(true).try_init();
    }

    /// Checks `svd_f32` on the 4x5 example matrix from the Wikipedia svd article.
    /// The singular values must be 3, sqrt(5), 2, 0 (one per row, as min(4, 5) = 4).
    /// U must be 4x4 with unit-norm columns.
    #[test]
    fn test_svd_wiki_rank_svd_f32() {
        log_init_test();
        log::info!("\n\n test_svd_wiki");
        let row_0: [f32; 5] = [1., 0., 0., 0., 2.];
        let row_1: [f32; 5] = [0., 0., 3., 0., 0.];
        let row_2: [f32; 5] = [0., 0., 0., 0., 0.];
        let row_3: [f32; 5] = [0., 2., 0., 0., 0.];

        let mut mat = ndarray::arr2(&[row_0, row_1, row_2, row_3]);
        let epsil: f32 = 1.0E-5;
        let res = svd_f32(&mut mat).unwrap();
        let computed_s = res.get_sigma().as_ref().unwrap();
        let sigma = ndarray::arr1(&[3., (5f32).sqrt(), 2., 0.]);
        // A shorter result would otherwise pass the element-wise loop silently.
        assert_eq!(computed_s.len(), sigma.len());
        for (i, (&exact, &computed)) in sigma.iter().zip(computed_s.iter()).enumerate() {
            log::debug!("sp i exact : {}, computed {}", exact, computed);
            // Relative error for non-zero values, absolute error for the zero one.
            let err = if exact > 0. {
                (1. - computed / exact).abs()
            } else {
                (exact - computed).abs()
            };
            assert!(
                err < epsil,
                "singular value {} : exact {}, computed {}",
                i,
                exact,
                computed
            );
        }
        // U (nrows x number of singular values) has orthonormal columns.
        let u = res.u.as_ref().unwrap();
        assert_eq!(u.dim(), (4, 4));
        for j in 0..u.ncols() {
            let norm = u.column(j).iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.).abs() < 1.0E-4, "column {} of U has norm {}", j, norm);
        }
    }

    /// A matrix that is not in standard (row-major) order must be rejected with an error.
    #[test]
    fn test_svd_f32_rejects_non_standard_layout() {
        log_init_test();
        let mat = ndarray::arr2(&[[1f32, 2., 3.], [4., 5., 6.]]);
        // reversed_axes keeps the buffer but makes the array column-major.
        let mut transposed = mat.reversed_axes();
        assert!(svd_f32(&mut transposed).is_err());
    }
}