use crate::error::Result;
use crate::runtime::Runtime;
use crate::sparse::CsrData;
use crate::tensor::Tensor;
use super::types::{
BiCgStabOptions, BiCgStabResult, CgOptions, CgResult, CgsOptions, CgsResult, GmresOptions,
GmresResult, JacobiOptions, JacobiResult, LgmresOptions, LgmresResult, MinresOptions,
MinresResult, QmrOptions, QmrResult, SorOptions, SorResult, SparseEigComplexResult,
SparseEigOptions, SparseEigResult, SparseSvdResult, SvdsOptions,
};
/// Sparse iterative linear solvers plus sparse eigenvalue / singular-value
/// routines, implemented once per runtime `R`.
///
/// Implementors must also implement the supertraits
/// `crate::algorithm::sparse_linalg::SparseLinAlgAlgorithms<R>` and
/// `crate::ops::LinalgOps<R>`.
///
/// Every linear-solver method (`gmres` through `sor`) has the same shape:
/// - `a` is the system matrix in CSR form (`CsrData<R>`);
/// - `b` is the right-hand side;
/// - `x0` is an optional initial guess;
/// - `options` is a method-specific options struct.
///
/// Each returns its own method-specific result type inside the crate `Result`.
///
/// NOTE(review): `validate_iterative_inputs` in this file checks for a square
/// `a`, a 1-D `b` of length `n`, and (when given) a 1-D `x0` of length `n`.
/// Implementors presumably call it before iterating — confirm per backend.
pub trait IterativeSolvers<R: Runtime>:
crate::algorithm::sparse_linalg::SparseLinAlgAlgorithms<R> + crate::ops::LinalgOps<R>
{
/// GMRES solver (per the method name) for the CSR system `a` with RHS `b`.
///
/// `x0` is an optional initial guess. Tuning lives in `GmresOptions`.
fn gmres(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: GmresOptions,
) -> Result<GmresResult<R>>;
/// BiCGSTAB solver (per the method name) for the CSR system `a` with RHS `b`.
///
/// `x0` is an optional initial guess. Tuning lives in `BiCgStabOptions`.
fn bicgstab(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: BiCgStabOptions,
) -> Result<BiCgStabResult<R>>;
/// Conjugate-gradient solver (per the method name) for the CSR system `a`
/// with RHS `b`.
///
/// NOTE(review): CG conventionally assumes a symmetric positive-definite
/// `a`. Whether that is checked or just assumed is up to the implementor —
/// confirm.
fn cg(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: CgOptions,
) -> Result<CgResult<R>>;
/// MINRES solver (per the method name) for the CSR system `a` with RHS `b`.
///
/// NOTE(review): MINRES conventionally targets symmetric `a` — confirm the
/// implementor's requirements.
fn minres(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: MinresOptions,
) -> Result<MinresResult<R>>;
/// Conjugate-gradient-squared (CGS) solver (per the method name) for the CSR
/// system `a` with RHS `b`.
///
/// `x0` is an optional initial guess.
fn cgs(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: CgsOptions,
) -> Result<CgsResult<R>>;
/// LGMRES solver (per the method name; presumably a GMRES variant) for the
/// CSR system `a` with RHS `b`.
///
/// `x0` is an optional initial guess.
fn lgmres(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: LgmresOptions,
) -> Result<LgmresResult<R>>;
/// QMR solver (per the method name) for the CSR system `a` with RHS `b`.
///
/// `x0` is an optional initial guess.
fn qmr(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: QmrOptions,
) -> Result<QmrResult<R>>;
/// Jacobi stationary iteration (per the method name) for the CSR system `a`
/// with RHS `b`.
///
/// `x0` is an optional initial guess.
fn jacobi(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: JacobiOptions,
) -> Result<JacobiResult<R>>;
/// SOR stationary iteration (per the method name) for the CSR system `a`
/// with RHS `b`.
///
/// The relaxation settings presumably live in `SorOptions` — confirm.
fn sor(
&self,
a: &CsrData<R>,
b: &Tensor<R>,
x0: Option<&Tensor<R>>,
options: SorOptions,
) -> Result<SorResult<R>>;
/// Sparse eigen-decomposition for a matrix treated as symmetric.
///
/// `k` is presumably the number of eigenpairs requested — confirm. Unlike
/// `sparse_eig`, the result type is the non-complex `SparseEigResult<R>`.
fn sparse_eig_symmetric(
&self,
a: &CsrData<R>,
k: usize,
options: SparseEigOptions,
) -> Result<SparseEigResult<R>>;
/// General (non-symmetric) sparse eigen-decomposition.
///
/// Returns `SparseEigComplexResult<R>`, since eigenvalues of a general
/// matrix may be complex. `k` is presumably the number of eigenpairs
/// requested — confirm.
fn sparse_eig(
&self,
a: &CsrData<R>,
k: usize,
options: SparseEigOptions,
) -> Result<SparseEigComplexResult<R>>;
/// Truncated sparse SVD (per the method name) of `a`.
///
/// `k` is presumably the number of singular triplets requested — confirm.
fn svds(&self, a: &CsrData<R>, k: usize, options: SvdsOptions) -> Result<SparseSvdResult<R>>;
}
/// Checks that the operands of an iterative solve have mutually consistent shapes.
///
/// - `a_shape` is the `[rows, cols]` shape of the system matrix.
/// - `b` is the right-hand side.
/// - `x0` is an optional initial guess.
///
/// # Returns
///
/// The system dimension `n`, which is the row count of the (square) matrix.
///
/// # Errors
///
/// Checks run in this order, and the first failure is returned:
/// 1. `Error::ShapeMismatch` if `a_shape` is not square. `expected` is
///    `[rows, rows]` and `got` is `[rows, cols]`.
/// 2. `Error::Internal` if `b` is not 1-D.
/// 3. `Error::ShapeMismatch` if `b` does not hold exactly `n` elements.
/// 4. The same two checks for `x0`, when it is `Some`.
pub fn validate_iterative_inputs<R: Runtime>(
    a_shape: [usize; 2],
    b: &Tensor<R>,
    x0: Option<&Tensor<R>>,
) -> Result<usize> {
    use crate::error::Error;

    let [rows, cols] = a_shape;
    if rows != cols {
        return Err(Error::ShapeMismatch {
            expected: vec![rows, rows],
            got: vec![rows, cols],
        });
    }
    let dim = rows;

    // Shared rule for every vector operand: it must be 1-D and hold exactly
    // `dim` elements. `label` is only used to word the rank error message.
    let check_vector = |vector: &Tensor<R>, label: &str| -> Result<()> {
        let rank = vector.ndim();
        if rank != 1 {
            return Err(Error::Internal(format!("Expected 1D {}, got {}D", label, rank)));
        }
        let len = vector.numel();
        if len != dim {
            return Err(Error::ShapeMismatch {
                expected: vec![dim],
                got: vec![len],
            });
        }
        Ok(())
    };

    check_vector(b, "right-hand side")?;
    if let Some(guess) = x0 {
        check_vector(guess, "initial guess")?;
    }
    Ok(dim)
}