use std::os::raw::c_int;
use arpack_sys::{__BindgenComplex, cnaupd_c, cneupd_c, znaupd_c, zneupd_c};
use num_complex::{Complex32, Complex64};
use crate::error::Error;
use crate::lock::lock;
use crate::solution::{EigSolution, usize_from_iparam};
/// Solver settings shared by [`smallest_eigenpair_c64`] and [`smallest_eigenpair_c32`].
#[derive(Debug, Clone)]
pub struct Options {
    /// Convergence tolerance forwarded to ARPACK as `TOL` (narrowed to `f32`
    /// by the single-precision solver). ARPACK documents a non-positive value
    /// as "use machine precision".
    pub tol: f64,
    /// Upper bound on Arnoldi restart iterations (ARPACK `IPARAM(3)`); the
    /// solvers reject zero and values that do not fit in a C `int`.
    pub max_iter: usize,
    /// Number of Arnoldi basis vectors (ARPACK `NCV`). `None` lets the solver
    /// choose a size from `n`.
    pub ncv: Option<usize>,
}

impl Default for Options {
    /// `tol = 0.0` (ARPACK's machine-precision default), `max_iter = 300`,
    /// and an automatically chosen `ncv`.
    fn default() -> Options {
        Options {
            max_iter: 300,
            ncv: None,
            tol: 0.0,
        }
    }
}
/// Computes the eigenpair of a double-precision complex linear operator whose
/// eigenvalue has the smallest real part.
///
/// Uses ARPACK's `znaupd`/`zneupd` reverse-communication driver in regular
/// mode (`bmat = "I"`, `IPARAM(7) = 1`, `which = "SR"`). The operator is
/// supplied implicitly: `matvec(x, y)` must write the operator applied to `x`
/// into `y`. Both slices have length `n`.
///
/// The crate lock from `crate::lock::lock` is held for the whole solve, and
/// `matvec` is invoked while it is held. NOTE(review): calling back into these
/// solvers from `matvec` may therefore deadlock; confirm against `crate::lock`.
///
/// # Errors
///
/// * [`Error::InvalidParam`] if `n < 4`, or if the (possibly defaulted) `ncv`
///   is not in `3..n`. Also if `options.max_iter` is zero, or if a size
///   overflows `usize` or `c_int`.
/// * [`Error::UnexpectedIdo`] if ARPACK requests anything other than a
///   matrix-vector product (`ido` of `-1` or `1`) or completion (`99`).
/// * [`Error::MaxIterReached`] if `znaupd` stops with `info == 1`.
/// * [`Error::AupdFailed`] / [`Error::EupdFailed`] for any other non-zero
///   `info` from `znaupd` / `zneupd`.
pub fn smallest_eigenpair_c64<F>(
    n: usize,
    matvec: F,
    options: &Options,
) -> Result<EigSolution<Complex64>, Error>
where
    F: FnMut(&[Complex64], &mut [Complex64]),
{
    let nev: c_int = 1;
    let nev_usize = nev as usize;
    if n < nev_usize + 3 {
        return Err(Error::InvalidParam(
            "n too small for complex Arnoldi (require n >= nev + 3)",
        ));
    }
    // Default subspace size 2*nev + 4, clamped into [nev + 2, n - 1]
    // (non-empty because n >= nev + 3 was checked above).
    let ncv = options
        .ncv
        .unwrap_or_else(|| (2 * nev_usize + 4).min(n - 1).max(nev_usize + 2));
    let n_i32 = c_int_from_usize(n)?;
    let ncv_i32 = c_int_from_usize(ncv)?;
    let max_iter_i32 = c_int_from_usize(options.max_iter)?;
    if !(nev > 0 && ncv_i32 >= nev + 2 && ncv_i32 < n_i32) {
        return Err(Error::InvalidParam(
            "require 0 < nev, nev + 2 <= ncv, and ncv < n",
        ));
    }
    if max_iter_i32 <= 0 {
        return Err(Error::InvalidParam("max_iter must be positive"));
    }

    // Workspace sizes required by znaupd/zneupd, computed with overflow checks.
    let v_len = n
        .checked_mul(ncv)
        .ok_or(Error::InvalidParam("n * ncv overflows usize"))?;
    let workd_len = n
        .checked_mul(3)
        .ok_or(Error::InvalidParam("3 * n overflows usize"))?;
    let ncv_sq = ncv
        .checked_mul(ncv)
        .ok_or(Error::InvalidParam("ncv * ncv overflows usize"))?;
    let three_ncv_sq = ncv_sq
        .checked_mul(3)
        .ok_or(Error::InvalidParam("3 * ncv^2 overflows usize"))?;
    let five_ncv = ncv
        .checked_mul(5)
        .ok_or(Error::InvalidParam("5 * ncv overflows usize"))?;
    let lworkl = three_ncv_sq
        .checked_add(five_ncv)
        .ok_or(Error::InvalidParam("3*ncv^2 + 5*ncv overflows usize"))?;
    let workev_len = ncv
        .checked_mul(2)
        .ok_or(Error::InvalidParam("2 * ncv overflows usize"))?;
    let lworkl_i32 = c_int_from_usize(lworkl)?;

    let zero = Complex64::new(0.0, 0.0);
    let mut resid = vec![zero; n];
    let mut v = vec![zero; v_len];
    let ldv = n_i32;
    let mut iparam = [0i32; 11];
    iparam[0] = 1; // IPARAM(1) = ISHIFT: exact shifts, so ARPACK never asks for shifts (ido = 3).
    iparam[2] = max_iter_i32; // IPARAM(3) = MXITER: maximum Arnoldi update iterations.
    iparam[3] = 1; // IPARAM(4) = NB: block size; ARPACK only supports 1.
    iparam[6] = 1; // IPARAM(7) = MODE: regular mode, OP = A.
    let mut ipntr = [0i32; 14];
    let mut workd = vec![zero; workd_len];
    let mut workl = vec![zero; lworkl];
    let mut rwork = vec![0.0f64; ncv];
    let bmat = c"I".as_ptr();
    let which = c"SR".as_ptr();

    let _guard = lock();
    let mut ido: c_int = 0;
    // info = 0 on input asks znaupd for a random starting residual.
    let mut info: c_int = 0;
    let mut matvec = matvec;
    // x is copied out of workd so y can be borrowed mutably from the same buffer.
    let mut x_buf = vec![zero; n];
    loop {
        // SAFETY: every pointer refers to a live, exclusively owned buffer of
        // the length ARPACK documents for znaupd: resid (n), v (n * ncv with
        // ldv = n), iparam (11), ipntr (14), workd (3n), workl
        // (lworkl = 3*ncv^2 + 5*ncv) and rwork (ncv). `bmat`/`which` point at
        // NUL-terminated literals with static lifetime.
        // `Complex64` and `__BindgenComplex<f64>` are both two consecutive f64
        // values (re, im).
        unsafe {
            znaupd_c(
                &mut ido,
                bmat,
                n_i32,
                which,
                nev,
                options.tol,
                resid.as_mut_ptr() as *mut __BindgenComplex<f64>,
                ncv_i32,
                v.as_mut_ptr() as *mut __BindgenComplex<f64>,
                ldv,
                iparam.as_mut_ptr(),
                ipntr.as_mut_ptr(),
                workd.as_mut_ptr() as *mut __BindgenComplex<f64>,
                workl.as_mut_ptr() as *mut __BindgenComplex<f64>,
                lworkl_i32,
                rwork.as_mut_ptr(),
                &mut info,
            );
        }
        match ido {
            // Request for y = OP * x. x starts at IPNTR(1) and y at IPNTR(2);
            // both are 1-based Fortran offsets into workd.
            -1 | 1 => {
                let x_off = (ipntr[0] - 1) as usize;
                let y_off = (ipntr[1] - 1) as usize;
                debug_assert!(x_off + n <= workd.len() && y_off + n <= workd.len());
                x_buf.copy_from_slice(&workd[x_off..x_off + n]);
                matvec(&x_buf, &mut workd[y_off..y_off + n]);
            }
            99 => break,
            other => return Err(Error::UnexpectedIdo(other)),
        }
    }
    if info == 1 {
        return Err(Error::MaxIterReached {
            iters: usize_from_iparam(iparam[2]),
            nconv: usize_from_iparam(iparam[4]),
            n_matvec: usize_from_iparam(iparam[8]),
        });
    }
    if info != 0 {
        return Err(Error::AupdFailed(info));
    }

    // rvec = 1 with howmny = "A": compute Ritz vectors for all converged values.
    let rvec: c_int = 1;
    let howmny = c"A".as_ptr();
    let mut select = vec![0i32; ncv];
    // zneupd documents D as an array of dimension NEV + 1, so allocate one
    // extra slot rather than exactly `nev`.
    let mut d = vec![zero; nev_usize + 1];
    // SIGMA is only referenced in shift-invert mode; it is unused in mode 1.
    let sigma = __BindgenComplex { re: 0.0, im: 0.0 };
    let mut workev = vec![zero; workev_len];
    let mut info_eup: c_int = 0;
    // SAFETY: the buffers are as in the znaupd call, plus select (ncv),
    // d (nev + 1) and workev (2 * ncv).
    // `v` is passed both as Z and as V. zneupd documents that Z may be the
    // leading columns of V when HOWMNY = 'A', in which case V is overwritten
    // by the Ritz vectors.
    // `Vec::as_mut_ptr` creates no intermediate reference, so the two raw
    // pointers may coexist.
    unsafe {
        zneupd_c(
            rvec,
            howmny,
            select.as_mut_ptr(),
            d.as_mut_ptr() as *mut __BindgenComplex<f64>,
            v.as_mut_ptr() as *mut __BindgenComplex<f64>,
            ldv,
            sigma,
            workev.as_mut_ptr() as *mut __BindgenComplex<f64>,
            bmat,
            n_i32,
            which,
            nev,
            options.tol,
            resid.as_mut_ptr() as *mut __BindgenComplex<f64>,
            ncv_i32,
            v.as_mut_ptr() as *mut __BindgenComplex<f64>,
            ldv,
            iparam.as_mut_ptr(),
            ipntr.as_mut_ptr(),
            workd.as_mut_ptr() as *mut __BindgenComplex<f64>,
            workl.as_mut_ptr() as *mut __BindgenComplex<f64>,
            lworkl_i32,
            rwork.as_mut_ptr(),
            &mut info_eup,
        );
    }
    if info_eup != 0 {
        return Err(Error::EupdFailed(info_eup));
    }
    let value = d[0];
    // The first column of Z (which aliases V) holds the wanted Ritz vector.
    let vector = v[..n].to_vec();
    Ok(EigSolution {
        eigenvalue: value,
        eigenvector: vector,
        iters: usize_from_iparam(iparam[2]),
        nconv: usize_from_iparam(iparam[4]),
        n_matvec: usize_from_iparam(iparam[8]),
    })
}
/// Computes the eigenpair of a single-precision complex linear operator whose
/// eigenvalue has the smallest real part.
///
/// Uses ARPACK's `cnaupd`/`cneupd` reverse-communication driver in regular
/// mode (`bmat = "I"`, `IPARAM(7) = 1`, `which = "SR"`). The operator is
/// supplied implicitly: `matvec(x, y)` must write the operator applied to `x`
/// into `y`. Both slices have length `n`. `options.tol` is narrowed to `f32`
/// before it is passed to ARPACK.
///
/// The crate lock from `crate::lock::lock` is held for the whole solve, and
/// `matvec` is invoked while it is held. NOTE(review): calling back into these
/// solvers from `matvec` may therefore deadlock; confirm against `crate::lock`.
///
/// # Errors
///
/// * [`Error::InvalidParam`] if `n < 4`, or if the (possibly defaulted) `ncv`
///   is not in `3..n`. Also if `options.max_iter` is zero, or if a size
///   overflows `usize` or `c_int`.
/// * [`Error::UnexpectedIdo`] if ARPACK requests anything other than a
///   matrix-vector product (`ido` of `-1` or `1`) or completion (`99`).
/// * [`Error::MaxIterReached`] if `cnaupd` stops with `info == 1`.
/// * [`Error::AupdFailed`] / [`Error::EupdFailed`] for any other non-zero
///   `info` from `cnaupd` / `cneupd`.
pub fn smallest_eigenpair_c32<F>(
    n: usize,
    matvec: F,
    options: &Options,
) -> Result<EigSolution<Complex32>, Error>
where
    F: FnMut(&[Complex32], &mut [Complex32]),
{
    let nev: c_int = 1;
    let nev_usize = nev as usize;
    if n < nev_usize + 3 {
        return Err(Error::InvalidParam(
            "n too small for complex Arnoldi (require n >= nev + 3)",
        ));
    }
    // Default subspace size 2*nev + 4, clamped into [nev + 2, n - 1]
    // (non-empty because n >= nev + 3 was checked above).
    let ncv = options
        .ncv
        .unwrap_or_else(|| (2 * nev_usize + 4).min(n - 1).max(nev_usize + 2));
    let n_i32 = c_int_from_usize(n)?;
    let ncv_i32 = c_int_from_usize(ncv)?;
    let max_iter_i32 = c_int_from_usize(options.max_iter)?;
    if !(nev > 0 && ncv_i32 >= nev + 2 && ncv_i32 < n_i32) {
        return Err(Error::InvalidParam(
            "require 0 < nev, nev + 2 <= ncv, and ncv < n",
        ));
    }
    if max_iter_i32 <= 0 {
        return Err(Error::InvalidParam("max_iter must be positive"));
    }

    // Workspace sizes required by cnaupd/cneupd, computed with overflow checks.
    let v_len = n
        .checked_mul(ncv)
        .ok_or(Error::InvalidParam("n * ncv overflows usize"))?;
    let workd_len = n
        .checked_mul(3)
        .ok_or(Error::InvalidParam("3 * n overflows usize"))?;
    let ncv_sq = ncv
        .checked_mul(ncv)
        .ok_or(Error::InvalidParam("ncv * ncv overflows usize"))?;
    let three_ncv_sq = ncv_sq
        .checked_mul(3)
        .ok_or(Error::InvalidParam("3 * ncv^2 overflows usize"))?;
    let five_ncv = ncv
        .checked_mul(5)
        .ok_or(Error::InvalidParam("5 * ncv overflows usize"))?;
    let lworkl = three_ncv_sq
        .checked_add(five_ncv)
        .ok_or(Error::InvalidParam("3*ncv^2 + 5*ncv overflows usize"))?;
    let workev_len = ncv
        .checked_mul(2)
        .ok_or(Error::InvalidParam("2 * ncv overflows usize"))?;
    let lworkl_i32 = c_int_from_usize(lworkl)?;

    let tol = options.tol as f32;
    let zero = Complex32::new(0.0, 0.0);
    let mut resid = vec![zero; n];
    let mut v = vec![zero; v_len];
    let ldv = n_i32;
    let mut iparam = [0i32; 11];
    iparam[0] = 1; // IPARAM(1) = ISHIFT: exact shifts, so ARPACK never asks for shifts (ido = 3).
    iparam[2] = max_iter_i32; // IPARAM(3) = MXITER: maximum Arnoldi update iterations.
    iparam[3] = 1; // IPARAM(4) = NB: block size; ARPACK only supports 1.
    iparam[6] = 1; // IPARAM(7) = MODE: regular mode, OP = A.
    let mut ipntr = [0i32; 14];
    let mut workd = vec![zero; workd_len];
    let mut workl = vec![zero; lworkl];
    let mut rwork = vec![0.0f32; ncv];
    let bmat = c"I".as_ptr();
    let which = c"SR".as_ptr();

    let _guard = lock();
    let mut ido: c_int = 0;
    // info = 0 on input asks cnaupd for a random starting residual.
    let mut info: c_int = 0;
    let mut matvec = matvec;
    // x is copied out of workd so y can be borrowed mutably from the same buffer.
    let mut x_buf = vec![zero; n];
    loop {
        // SAFETY: every pointer refers to a live, exclusively owned buffer of
        // the length ARPACK documents for cnaupd: resid (n), v (n * ncv with
        // ldv = n), iparam (11), ipntr (14), workd (3n), workl
        // (lworkl = 3*ncv^2 + 5*ncv) and rwork (ncv). `bmat`/`which` point at
        // NUL-terminated literals with static lifetime.
        // `Complex32` and `__BindgenComplex<f32>` are both two consecutive f32
        // values (re, im).
        unsafe {
            cnaupd_c(
                &mut ido,
                bmat,
                n_i32,
                which,
                nev,
                tol,
                resid.as_mut_ptr() as *mut __BindgenComplex<f32>,
                ncv_i32,
                v.as_mut_ptr() as *mut __BindgenComplex<f32>,
                ldv,
                iparam.as_mut_ptr(),
                ipntr.as_mut_ptr(),
                workd.as_mut_ptr() as *mut __BindgenComplex<f32>,
                workl.as_mut_ptr() as *mut __BindgenComplex<f32>,
                lworkl_i32,
                rwork.as_mut_ptr(),
                &mut info,
            );
        }
        match ido {
            // Request for y = OP * x. x starts at IPNTR(1) and y at IPNTR(2);
            // both are 1-based Fortran offsets into workd.
            -1 | 1 => {
                let x_off = (ipntr[0] - 1) as usize;
                let y_off = (ipntr[1] - 1) as usize;
                debug_assert!(x_off + n <= workd.len() && y_off + n <= workd.len());
                x_buf.copy_from_slice(&workd[x_off..x_off + n]);
                matvec(&x_buf, &mut workd[y_off..y_off + n]);
            }
            99 => break,
            other => return Err(Error::UnexpectedIdo(other)),
        }
    }
    if info == 1 {
        return Err(Error::MaxIterReached {
            iters: usize_from_iparam(iparam[2]),
            nconv: usize_from_iparam(iparam[4]),
            n_matvec: usize_from_iparam(iparam[8]),
        });
    }
    if info != 0 {
        return Err(Error::AupdFailed(info));
    }

    // rvec = 1 with howmny = "A": compute Ritz vectors for all converged values.
    let rvec: c_int = 1;
    let howmny = c"A".as_ptr();
    let mut select = vec![0i32; ncv];
    // cneupd documents D as an array of dimension NEV + 1, so allocate one
    // extra slot rather than exactly `nev`.
    let mut d = vec![zero; nev_usize + 1];
    // SIGMA is only referenced in shift-invert mode; it is unused in mode 1.
    let sigma = __BindgenComplex {
        re: 0.0_f32,
        im: 0.0_f32,
    };
    let mut workev = vec![zero; workev_len];
    let mut info_eup: c_int = 0;
    // SAFETY: the buffers are as in the cnaupd call, plus select (ncv),
    // d (nev + 1) and workev (2 * ncv).
    // `v` is passed both as Z and as V. cneupd documents that Z may be the
    // leading columns of V when HOWMNY = 'A', in which case V is overwritten
    // by the Ritz vectors.
    // `Vec::as_mut_ptr` creates no intermediate reference, so the two raw
    // pointers may coexist.
    unsafe {
        cneupd_c(
            rvec,
            howmny,
            select.as_mut_ptr(),
            d.as_mut_ptr() as *mut __BindgenComplex<f32>,
            v.as_mut_ptr() as *mut __BindgenComplex<f32>,
            ldv,
            sigma,
            workev.as_mut_ptr() as *mut __BindgenComplex<f32>,
            bmat,
            n_i32,
            which,
            nev,
            tol,
            resid.as_mut_ptr() as *mut __BindgenComplex<f32>,
            ncv_i32,
            v.as_mut_ptr() as *mut __BindgenComplex<f32>,
            ldv,
            iparam.as_mut_ptr(),
            ipntr.as_mut_ptr(),
            workd.as_mut_ptr() as *mut __BindgenComplex<f32>,
            workl.as_mut_ptr() as *mut __BindgenComplex<f32>,
            lworkl_i32,
            rwork.as_mut_ptr(),
            &mut info_eup,
        );
    }
    if info_eup != 0 {
        return Err(Error::EupdFailed(info_eup));
    }
    let value = d[0];
    // The first column of Z (which aliases V) holds the wanted Ritz vector.
    let vector = v[..n].to_vec();
    Ok(EigSolution {
        eigenvalue: value,
        eigenvector: vector,
        iters: usize_from_iparam(iparam[2]),
        nconv: usize_from_iparam(iparam[4]),
        n_matvec: usize_from_iparam(iparam[8]),
    })
}
fn c_int_from_usize(value: usize) -> Result<c_int, Error> {
c_int::try_from(value).map_err(|_| Error::InvalidParam("value does not fit in c_int"))
}