ndarray_linalg/lobpcg/
eig.rs1use super::lobpcg::{lobpcg, LobpcgResult, Order};
2use crate::{generate, Scalar};
3use lax::Lapack;
4
5use ndarray::prelude::*;
8use ndarray::stack;
9use ndarray::ScalarOperand;
10use num_traits::{Float, NumCast};
11
12pub struct TruncatedEig<A: Scalar> {
19 order: Order,
20 problem: Array2<A>,
21 pub constraints: Option<Array2<A>>,
22 preconditioner: Option<Array2<A>>,
23 precision: f32,
24 maxiter: usize,
25}
26
27impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedEig<A> {
28 pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> {
29 TruncatedEig {
30 precision: 1e-5,
31 maxiter: problem.len_of(Axis(0)) * 2,
32 preconditioner: None,
33 constraints: None,
34 order,
35 problem,
36 }
37 }
38
39 pub fn precision(mut self, precision: f32) -> Self {
40 self.precision = precision;
41
42 self
43 }
44
45 pub fn maxiter(mut self, maxiter: usize) -> Self {
46 self.maxiter = maxiter;
47
48 self
49 }
50
51 pub fn orthogonal_to(mut self, constraints: Array2<A>) -> Self {
52 self.constraints = Some(constraints);
53
54 self
55 }
56
57 pub fn precondition_with(mut self, preconditioner: Array2<A>) -> Self {
58 self.preconditioner = Some(preconditioner);
59
60 self
61 }
62
63 pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
65 let x: Array2<f64> = generate::random((self.problem.len_of(Axis(0)), num));
66 let x = x.mapv(|x| NumCast::from(x).unwrap());
67
68 if let Some(ref preconditioner) = self.preconditioner {
69 lobpcg(
70 |y| self.problem.dot(&y),
71 x,
72 |mut y| {
73 let p = preconditioner.dot(&y);
74 y.assign(&p);
75 },
76 self.constraints.clone(),
77 self.precision,
78 self.maxiter,
79 self.order.clone(),
80 )
81 } else {
82 lobpcg(
83 |y| self.problem.dot(&y),
84 x,
85 |_| {},
86 self.constraints.clone(),
87 self.precision,
88 self.maxiter,
89 self.order.clone(),
90 )
91 }
92 }
93}
94
95impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> IntoIterator
96 for TruncatedEig<A>
97{
98 type Item = (Array1<A>, Array2<A>);
99 type IntoIter = TruncatedEigIterator<A>;
100
101 fn into_iter(self) -> TruncatedEigIterator<A> {
102 TruncatedEigIterator {
103 step_size: 1,
104 remaining: self.problem.len_of(Axis(0)),
105 eig: self,
106 }
107 }
108}
109
110pub struct TruncatedEigIterator<A: Scalar> {
115 step_size: usize,
116 remaining: usize,
117 eig: TruncatedEig<A>,
118}
119
120impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
121 for TruncatedEigIterator<A>
122{
123 type Item = (Array1<A>, Array2<A>);
124
125 fn next(&mut self) -> Option<Self::Item> {
126 if self.remaining == 0 {
127 return None;
128 }
129
130 let step_size = usize::min(self.step_size, self.remaining);
131 let res = self.eig.decompose(step_size);
132
133 match res {
134 LobpcgResult::Ok(vals, vecs, norms) | LobpcgResult::Err(vals, vecs, norms, _) => {
135 for r_norm in norms {
137 if r_norm > NumCast::from(0.1).unwrap() {
138 return None;
139 }
140 }
141
142 let new_constraints = if let Some(ref constraints) = self.eig.constraints {
144 let eigvecs_arr: Vec<_> = constraints
145 .columns()
146 .into_iter()
147 .chain(vecs.columns().into_iter())
148 .collect();
149
150 stack(Axis(1), &eigvecs_arr).unwrap()
151 } else {
152 vecs.clone()
153 };
154
155 self.eig.constraints = Some(new_constraints);
156 self.remaining -= step_size;
157
158 Some((vals, vecs))
159 }
160 LobpcgResult::NoResult(_) => None,
161 }
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedEig;
    use ndarray::{arr1, Array2};

    /// The three largest eigenvalues of diag(1, ..., 20) are 20, 19 and 18. The iterator must
    /// yield all three (one per step), and each must be close to its exact value.
    #[test]
    fn test_truncated_eig() {
        let diag = arr1(&[
            1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
            20.,
        ]);
        let a = Array2::from_diag(&diag);

        let teig = TruncatedEig::new(a, Order::Largest)
            .precision(1e-5)
            .maxiter(500);

        let res = teig
            .into_iter()
            .take(3)
            .flat_map(|x| x.0.to_vec())
            .collect::<Vec<_>>();
        let ground_truth: Vec<f64> = vec![20., 19., 18.];

        // `zip` silently truncates to the shorter side. Without this check, an iterator that
        // stopped early (or yielded nothing) would make the error sum trivially small.
        assert_eq!(
            res.len(),
            ground_truth.len(),
            "expected 3 eigenvalues, got {:?}",
            res
        );

        let sq_err: f64 = ground_truth
            .iter()
            .zip(res.iter())
            .map(|(x, y)| (x - y) * (x - y))
            .sum();
        assert!(
            sq_err < 0.01,
            "squared error {} too large for eigenvalues {:?}",
            sq_err,
            res
        );
    }
}