ndarray_linalg/lobpcg/
eig.rs

1use super::lobpcg::{lobpcg, LobpcgResult, Order};
2use crate::{generate, Scalar};
3use lax::Lapack;
4
// Implements truncated eigenvalue decomposition.
7use ndarray::prelude::*;
8use ndarray::stack;
9use ndarray::ScalarOperand;
10use num_traits::{Float, NumCast};
11
12/// Truncated eigenproblem solver
13///
14/// This struct wraps the LOBPCG algorithm and provides convenient builder-pattern access to
15/// parameter like maximal iteration, precision and constraint matrix. Furthermore it allows
16/// conversion into a iterative solver where each iteration step yields a new eigenvalue/vector
17/// pair.
18pub struct TruncatedEig<A: Scalar> {
19    order: Order,
20    problem: Array2<A>,
21    pub constraints: Option<Array2<A>>,
22    preconditioner: Option<Array2<A>>,
23    precision: f32,
24    maxiter: usize,
25}
26
27impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> TruncatedEig<A> {
28    pub fn new(problem: Array2<A>, order: Order) -> TruncatedEig<A> {
29        TruncatedEig {
30            precision: 1e-5,
31            maxiter: problem.len_of(Axis(0)) * 2,
32            preconditioner: None,
33            constraints: None,
34            order,
35            problem,
36        }
37    }
38
39    pub fn precision(mut self, precision: f32) -> Self {
40        self.precision = precision;
41
42        self
43    }
44
45    pub fn maxiter(mut self, maxiter: usize) -> Self {
46        self.maxiter = maxiter;
47
48        self
49    }
50
51    pub fn orthogonal_to(mut self, constraints: Array2<A>) -> Self {
52        self.constraints = Some(constraints);
53
54        self
55    }
56
57    pub fn precondition_with(mut self, preconditioner: Array2<A>) -> Self {
58        self.preconditioner = Some(preconditioner);
59
60        self
61    }
62
63    // calculate the eigenvalues decompose
64    pub fn decompose(&self, num: usize) -> LobpcgResult<A> {
65        let x: Array2<f64> = generate::random((self.problem.len_of(Axis(0)), num));
66        let x = x.mapv(|x| NumCast::from(x).unwrap());
67
68        if let Some(ref preconditioner) = self.preconditioner {
69            lobpcg(
70                |y| self.problem.dot(&y),
71                x,
72                |mut y| {
73                    let p = preconditioner.dot(&y);
74                    y.assign(&p);
75                },
76                self.constraints.clone(),
77                self.precision,
78                self.maxiter,
79                self.order.clone(),
80            )
81        } else {
82            lobpcg(
83                |y| self.problem.dot(&y),
84                x,
85                |_| {},
86                self.constraints.clone(),
87                self.precision,
88                self.maxiter,
89                self.order.clone(),
90            )
91        }
92    }
93}
94
95impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> IntoIterator
96    for TruncatedEig<A>
97{
98    type Item = (Array1<A>, Array2<A>);
99    type IntoIter = TruncatedEigIterator<A>;
100
101    fn into_iter(self) -> TruncatedEigIterator<A> {
102        TruncatedEigIterator {
103            step_size: 1,
104            remaining: self.problem.len_of(Axis(0)),
105            eig: self,
106        }
107    }
108}
109
110/// Truncate eigenproblem iterator
111///
112/// This wraps a truncated eigenproblem and provides an iterator where each step yields a new
113/// eigenvalue/vector pair. Useful for generating pairs until a certain condition is met.
114pub struct TruncatedEigIterator<A: Scalar> {
115    step_size: usize,
116    remaining: usize,
117    eig: TruncatedEig<A>,
118}
119
120impl<A: Float + Scalar + ScalarOperand + Lapack + PartialOrd + Default> Iterator
121    for TruncatedEigIterator<A>
122{
123    type Item = (Array1<A>, Array2<A>);
124
125    fn next(&mut self) -> Option<Self::Item> {
126        if self.remaining == 0 {
127            return None;
128        }
129
130        let step_size = usize::min(self.step_size, self.remaining);
131        let res = self.eig.decompose(step_size);
132
133        match res {
134            LobpcgResult::Ok(vals, vecs, norms) | LobpcgResult::Err(vals, vecs, norms, _) => {
135                // abort if any eigenproblem did not converge
136                for r_norm in norms {
137                    if r_norm > NumCast::from(0.1).unwrap() {
138                        return None;
139                    }
140                }
141
142                // add the new eigenvector to the internal constrain matrix
143                let new_constraints = if let Some(ref constraints) = self.eig.constraints {
144                    let eigvecs_arr: Vec<_> = constraints
145                        .columns()
146                        .into_iter()
147                        .chain(vecs.columns().into_iter())
148                        .collect();
149
150                    stack(Axis(1), &eigvecs_arr).unwrap()
151                } else {
152                    vecs.clone()
153                };
154
155                self.eig.constraints = Some(new_constraints);
156                self.remaining -= step_size;
157
158                Some((vals, vecs))
159            }
160            LobpcgResult::NoResult(_) => None,
161        }
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::Order;
    use super::TruncatedEig;
    use ndarray::{arr1, Array2};

    /// The three largest eigenvalues of diag(1, ..., 20) must be 20, 19 and 18.
    #[test]
    fn test_truncated_eig() {
        let diag = arr1(&[
            1., 2., 3., 4., 5., 6., 7., 8., 9., 10., 11., 12., 13., 14., 15., 16., 17., 18., 19.,
            20.,
        ]);
        let a = Array2::from_diag(&diag);

        let teig = TruncatedEig::new(a, Order::Largest)
            .precision(1e-5)
            .maxiter(500);

        let res: Vec<f64> = teig
            .into_iter()
            .take(3)
            .flat_map(|x| x.0.to_vec())
            .collect();
        let ground_truth: [f64; 3] = [20., 19., 18.];

        // `zip` stops at the shorter side, so an iterator that aborted early would otherwise
        // make the error sum below pass vacuously.
        assert_eq!(res.len(), ground_truth.len());

        let sq_err: f64 = ground_truth
            .iter()
            .zip(&res)
            .map(|(x, y)| (x - y) * (x - y))
            .sum();
        assert!(sq_err < 0.01, "squared error too large: {}", sq_err);
    }
}