1use std::f64::EPSILON;
2
3use arpack_ng_sys::*;
4use ndarray::prelude::*;
5use num_complex::Complex64;
6
7use crate::{Arpack, Error, Which, MUTEX, ZERO};
8
9impl Arpack for Array2<Complex64> {
10 type Result = Array1<Complex64>;
11 type ResultVec = (Array1<Complex64>, Array2<Complex64>);
12
13 fn eigenvalues(
14 &self,
15 which: &Which,
16 nev: usize,
17 ncv: usize,
18 maxiter: usize,
19 ) -> Result<Self::Result, Error> {
20 if !self.is_square() {
21 return Err(Error::NonSquare);
22 }
23 let n = self.dim().0;
24 let (val, _) = arpack_c64(
25 |v1, mut v2| v2.assign(&self.dot(&v1)),
26 n,
27 which.as_str(),
28 nev,
29 ncv,
30 maxiter,
31 true,
32 )?;
33 Ok(val)
34 }
35
36 fn eigenvectors(
37 &self,
38 which: &Which,
39 nev: usize,
40 ncv: usize,
41 maxiter: usize,
42 ) -> Result<Self::ResultVec, Error> {
43 if !self.is_square() {
44 return Err(Error::NonSquare);
45 }
46 let n = self.dim().0;
47 arpack_c64(
48 |v1, mut v2| v2.assign(&self.dot(&v1)),
49 n,
50 which.as_str(),
51 nev,
52 ncv,
53 maxiter,
54 true,
55 )
56 }
57}
58
59pub fn eigenvalues<F>(
60 av: F,
61 n: usize,
62 which: &Which,
63 nev: usize,
64 ncv: usize,
65 maxiter: usize,
66) -> Result<Array1<Complex64>, Error>
67where
68 F: FnMut(ArrayView1<Complex64>, ArrayViewMut1<Complex64>),
69{
70 let (res, _) = arpack_c64(av, n, which.as_str(), nev, ncv, maxiter, true)?;
71 Ok(res)
72}
73
74pub fn eigenvectors<F>(
75 av: F,
76 n: usize,
77 which: &Which,
78 nev: usize,
79 ncv: usize,
80 maxiter: usize,
81) -> Result<(Array1<Complex64>, Array2<Complex64>), Error>
82where
83 F: FnMut(ArrayView1<Complex64>, ArrayViewMut1<Complex64>),
84{
85 arpack_c64(av, n, which.as_str(), nev, ncv, maxiter, true)
86}
87
88fn arpack_c64<F>(
89 mut av: F,
90 n: usize,
91 which: &str,
92 nev: usize,
93 ncv: usize,
94 maxiter: usize,
95 vectors: bool,
96) -> Result<(Array1<Complex64>, Array2<Complex64>), Error>
97where
98 F: FnMut(ArrayView1<Complex64>, ArrayViewMut1<Complex64>),
99{
100 let g = MUTEX.lock().unwrap();
101 let mut ido = 0;
102 let mut resid: Array1<Complex64> = Array1::zeros(n);
103 let mut v: Array2<Complex64> = Array2::zeros((n, ncv));
104 let mut iparam = [0; 11];
105 iparam[0] = 1;
106 iparam[2] = maxiter as i32;
107 iparam[6] = 1;
108 let mut ipntr = [0; 14];
109 let mut workd = Array1::zeros(3 * n);
110 let lworkl = 3 * ncv.pow(2) + 6 * ncv;
111 let mut workl: Array1<Complex64> = Array1::zeros(lworkl);
112 let mut rwork = vec![0.; ncv];
113 let mut info = 0;
114 while ido != 99 {
115 unsafe {
116 znaupd_c(
117 &mut ido,
118 "I".as_ptr() as *const i8,
119 n as i32,
120 which.as_ptr() as *const i8,
121 nev as i32,
122 EPSILON,
123 resid.as_mut_ptr() as *mut __BindgenComplex<f64>,
124 ncv as i32,
125 v.as_mut_ptr() as *mut __BindgenComplex<f64>,
126 n as i32,
127 iparam.as_mut_ptr(),
128 ipntr.as_mut_ptr(),
129 workd.as_mut_ptr() as *mut __BindgenComplex<f64>,
130 workl.as_mut_ptr() as *mut __BindgenComplex<f64>,
131 lworkl as i32,
132 rwork.as_mut_ptr(),
133 &mut info,
134 );
135 }
136 if (ido == -1) || (ido == 1) {
137 let v = workd
138 .slice(s![ipntr[0] as usize - 1..ipntr[0] as usize + n - 1])
139 .to_owned();
140 av(
141 v.view(),
142 workd.slice_mut(s![ipntr[1] as usize - 1..ipntr[1] as usize + n - 1]),
143 );
144 }
145 }
146 match info {
147 0 | 1 | 2 => {}
148 -1 => return Err(Error::IllegalParameters("N must be positive.".to_string())),
149 -2 => {
150 return Err(Error::IllegalParameters(
151 "NEV must be positive.".to_string(),
152 ))
153 }
154 -3 => {
155 return Err(Error::IllegalParameters(
156 "NCV-NEV >= 2 and less than or equal to N.".to_string(),
157 ))
158 }
159 -4 => {
160 return Err(Error::IllegalParameters(
161 "Maximum iterations must be greater than 0.".to_string(),
162 ))
163 }
164 -5 => {
165 return Err(Error::IllegalParameters(
166 "Maximum iterations must be greater than 0.".to_string(),
167 ))
168 }
169 i => return Err(Error::Other(i)),
170 }
171
172 let select = vec![false as i32; ncv];
173 let mut d: Array1<Complex64> = Array1::zeros(nev + 1);
174 let mut z: Array2<Complex64> = Array2::zeros((n, nev));
175 let mut workev: Array1<Complex64> = Array1::zeros(2 * ncv);
176 unsafe {
177 zneupd_c(
178 vectors as i32,
179 "A".as_ptr() as *const i8,
180 select.as_ptr(),
181 d.as_mut_ptr() as *mut __BindgenComplex<f64>,
182 z.as_mut_ptr() as *mut __BindgenComplex<f64>,
183 n as i32,
184 ZERO,
185 workev.as_mut_ptr() as *mut __BindgenComplex<f64>,
186 "I".as_ptr() as *const i8,
187 n as i32,
188 "LR".as_ptr() as *const i8,
189 nev as i32,
190 EPSILON,
191 resid.as_mut_ptr() as *mut __BindgenComplex<f64>,
192 ncv as i32,
193 v.as_mut_ptr() as *mut __BindgenComplex<f64>,
194 n as i32,
195 iparam.as_mut_ptr(),
196 ipntr.as_mut_ptr(),
197 workd.as_mut_ptr() as *mut __BindgenComplex<f64>,
198 workl.as_mut_ptr() as *mut __BindgenComplex<f64>,
199 lworkl as i32,
200 rwork.as_mut_ptr(),
201 &mut info,
202 );
203 }
204 drop(g);
205 Ok((d.slice(s![0..nev]).to_owned(), z))
206}