use std::os::raw::c_int;
use arpack_sys::{dsaupd_c, dseupd_c, ssaupd_c, sseupd_c};
use crate::error::Error;
use crate::lock::lock;
use crate::solution::{usize_from_iparam, EigSolution};
/// Tuning knobs shared by the ARPACK-backed smallest-eigenpair solvers.
///
/// Start from [`Options::default`] and override individual fields with
/// struct-update syntax, e.g. `Options { max_iter: 1000, ..Options::default() }`.
#[derive(Clone, Debug)]
pub struct Options {
    /// Convergence tolerance handed to ARPACK as `TOL` (narrowed to `f32` by
    /// the single-precision solver). ARPACK documents that a value `<= 0`
    /// selects machine precision.
    pub tol: f64,
    /// Upper bound on restart iterations, written into ARPACK's `iparam[2]`.
    /// Must be non-zero and representable as a C `int`.
    pub max_iter: usize,
    /// Number of Lanczos basis vectors (`NCV`). `None` lets the solver pick
    /// `min(2 * nev + 4, n - 1)`, but never fewer than `nev + 1`.
    pub ncv: Option<usize>,
}

impl Default for Options {
    /// Tolerance `0.0`, at most 300 iterations, automatically sized `ncv`.
    fn default() -> Self {
        let (tol, max_iter, ncv) = (0.0, 300, None);
        Options { tol, max_iter, ncv }
    }
}
pub fn smallest_eigenpair_f64<F>(
n: usize,
matvec: F,
options: &Options,
) -> Result<EigSolution<f64>, Error>
where
F: FnMut(&[f64], &mut [f64]),
{
let nev: c_int = 1;
let nev_usize = nev as usize;
if n < nev_usize + 2 {
return Err(Error::InvalidParam(
"n too small for ARPACK (require n >= nev + 2)",
));
}
let ncv = options.ncv.unwrap_or_else(|| {
(2 * nev_usize + 4).min(n - 1).max(nev_usize + 1)
});
let n_i32 = c_int_from_usize(n, "n")?;
let ncv_i32 = c_int_from_usize(ncv, "ncv")?;
let max_iter_i32 = c_int_from_usize(options.max_iter, "max_iter")?;
if !(nev > 0 && nev < ncv_i32 && ncv_i32 < n_i32) {
return Err(Error::InvalidParam("require 0 < nev < ncv < n"));
}
if max_iter_i32 <= 0 {
return Err(Error::InvalidParam("max_iter must be positive"));
}
let v_len = n.checked_mul(ncv).ok_or(Error::InvalidParam(
"n * ncv overflows usize",
))?;
let workd_len = n.checked_mul(3).ok_or(Error::InvalidParam(
"3 * n overflows usize",
))?;
let lworkl = ncv
.checked_add(8)
.and_then(|s| ncv.checked_mul(s))
.ok_or(Error::InvalidParam(
"ncv * (ncv + 8) overflows usize",
))?;
let lworkl_i32 = c_int_from_usize(lworkl, "lworkl")?;
let mut resid = vec![0.0f64; n];
let mut v = vec![0.0f64; v_len];
let ldv = n_i32;
let mut iparam = [0i32; 11];
iparam[0] = 1; iparam[2] = max_iter_i32;
iparam[3] = 1; iparam[6] = 1; let mut ipntr = [0i32; 11];
let mut workd = vec![0.0f64; workd_len];
let mut workl = vec![0.0f64; lworkl];
let bmat = c"I".as_ptr();
let which = c"SA".as_ptr();
let _guard = lock();
let mut ido: c_int = 0;
let mut info: c_int = 0;
let mut matvec = matvec;
let mut x_buf = vec![0.0f64; n];
loop {
unsafe {
dsaupd_c(
&mut ido,
bmat,
n_i32,
which,
nev,
options.tol,
resid.as_mut_ptr(),
ncv_i32,
v.as_mut_ptr(),
ldv,
iparam.as_mut_ptr(),
ipntr.as_mut_ptr(),
workd.as_mut_ptr(),
workl.as_mut_ptr(),
lworkl_i32,
&mut info,
);
}
match ido {
-1 | 1 => {
let x_off = (ipntr[0] - 1) as usize;
let y_off = (ipntr[1] - 1) as usize;
debug_assert!(x_off + n <= workd.len() && y_off + n <= workd.len());
x_buf.copy_from_slice(&workd[x_off..x_off + n]);
matvec(&x_buf, &mut workd[y_off..y_off + n]);
}
99 => break,
other => return Err(Error::UnexpectedIdo(other)),
}
}
if info == 1 {
return Err(Error::MaxIterReached {
iters: usize_from_iparam(iparam[2]),
nconv: usize_from_iparam(iparam[4]),
n_matvec: usize_from_iparam(iparam[8]),
});
}
if info != 0 {
return Err(Error::AupdFailed(info));
}
let rvec: c_int = 1;
let howmny = c"A".as_ptr();
let mut select = vec![0i32; ncv];
let mut d = vec![0.0f64; nev as usize];
let sigma = 0.0f64;
let mut info_eup: c_int = 0;
unsafe {
dseupd_c(
rvec,
howmny,
select.as_mut_ptr(),
d.as_mut_ptr(),
v.as_mut_ptr(),
ldv,
sigma,
bmat,
n_i32,
which,
nev,
options.tol,
resid.as_mut_ptr(),
ncv_i32,
v.as_mut_ptr(),
ldv,
iparam.as_mut_ptr(),
ipntr.as_mut_ptr(),
workd.as_mut_ptr(),
workl.as_mut_ptr(),
lworkl_i32,
&mut info_eup,
);
}
if info_eup != 0 {
return Err(Error::EupdFailed(info_eup));
}
let value = d[0];
let mut vector = vec![0.0f64; n];
vector.copy_from_slice(&v[..n]);
Ok(EigSolution {
eigenvalue: value,
eigenvector: vector,
iters: usize_from_iparam(iparam[2]),
nconv: usize_from_iparam(iparam[4]),
n_matvec: usize_from_iparam(iparam[8]),
})
}
pub fn smallest_eigenpair_f32<F>(
n: usize,
matvec: F,
options: &Options,
) -> Result<EigSolution<f32>, Error>
where
F: FnMut(&[f32], &mut [f32]),
{
let nev: c_int = 1;
let nev_usize = nev as usize;
if n < nev_usize + 2 {
return Err(Error::InvalidParam(
"n too small for ARPACK (require n >= nev + 2)",
));
}
let ncv = options
.ncv
.unwrap_or_else(|| (2 * nev_usize + 4).min(n - 1).max(nev_usize + 1));
let n_i32 = c_int_from_usize(n, "n")?;
let ncv_i32 = c_int_from_usize(ncv, "ncv")?;
let max_iter_i32 = c_int_from_usize(options.max_iter, "max_iter")?;
if !(nev > 0 && nev < ncv_i32 && ncv_i32 < n_i32) {
return Err(Error::InvalidParam("require 0 < nev < ncv < n"));
}
if max_iter_i32 <= 0 {
return Err(Error::InvalidParam("max_iter must be positive"));
}
let v_len = n
.checked_mul(ncv)
.ok_or(Error::InvalidParam("n * ncv overflows usize"))?;
let workd_len = n
.checked_mul(3)
.ok_or(Error::InvalidParam("3 * n overflows usize"))?;
let lworkl = ncv
.checked_add(8)
.and_then(|s| ncv.checked_mul(s))
.ok_or(Error::InvalidParam("ncv * (ncv + 8) overflows usize"))?;
let lworkl_i32 = c_int_from_usize(lworkl, "lworkl")?;
let tol = options.tol as f32;
let mut resid = vec![0.0f32; n];
let mut v = vec![0.0f32; v_len];
let ldv = n_i32;
let mut iparam = [0i32; 11];
iparam[0] = 1;
iparam[2] = max_iter_i32;
iparam[3] = 1;
iparam[6] = 1;
let mut ipntr = [0i32; 11];
let mut workd = vec![0.0f32; workd_len];
let mut workl = vec![0.0f32; lworkl];
let bmat = c"I".as_ptr();
let which = c"SA".as_ptr();
let _guard = lock();
let mut ido: c_int = 0;
let mut info: c_int = 0;
let mut matvec = matvec;
let mut x_buf = vec![0.0f32; n];
loop {
unsafe {
ssaupd_c(
&mut ido,
bmat,
n_i32,
which,
nev,
tol,
resid.as_mut_ptr(),
ncv_i32,
v.as_mut_ptr(),
ldv,
iparam.as_mut_ptr(),
ipntr.as_mut_ptr(),
workd.as_mut_ptr(),
workl.as_mut_ptr(),
lworkl_i32,
&mut info,
);
}
match ido {
-1 | 1 => {
let x_off = (ipntr[0] - 1) as usize;
let y_off = (ipntr[1] - 1) as usize;
debug_assert!(x_off + n <= workd.len() && y_off + n <= workd.len());
x_buf.copy_from_slice(&workd[x_off..x_off + n]);
matvec(&x_buf, &mut workd[y_off..y_off + n]);
}
99 => break,
other => return Err(Error::UnexpectedIdo(other)),
}
}
if info == 1 {
return Err(Error::MaxIterReached {
iters: usize_from_iparam(iparam[2]),
nconv: usize_from_iparam(iparam[4]),
n_matvec: usize_from_iparam(iparam[8]),
});
}
if info != 0 {
return Err(Error::AupdFailed(info));
}
let rvec: c_int = 1;
let howmny = c"A".as_ptr();
let mut select = vec![0i32; ncv];
let mut d = vec![0.0f32; nev as usize];
let sigma = 0.0f32;
let mut info_eup: c_int = 0;
unsafe {
sseupd_c(
rvec,
howmny,
select.as_mut_ptr(),
d.as_mut_ptr(),
v.as_mut_ptr(),
ldv,
sigma,
bmat,
n_i32,
which,
nev,
tol,
resid.as_mut_ptr(),
ncv_i32,
v.as_mut_ptr(),
ldv,
iparam.as_mut_ptr(),
ipntr.as_mut_ptr(),
workd.as_mut_ptr(),
workl.as_mut_ptr(),
lworkl_i32,
&mut info_eup,
);
}
if info_eup != 0 {
return Err(Error::EupdFailed(info_eup));
}
let value = d[0];
let mut vector = vec![0.0f32; n];
vector.copy_from_slice(&v[..n]);
Ok(EigSolution {
eigenvalue: value,
eigenvector: vector,
iters: usize_from_iparam(iparam[2]),
nconv: usize_from_iparam(iparam[4]),
n_matvec: usize_from_iparam(iparam[8]),
})
}
fn c_int_from_usize(value: usize, name: &'static str) -> Result<c_int, Error> {
c_int::try_from(value).map_err(|_| {
let _ = name; Error::InvalidParam("value does not fit in c_int")
})
}