use crate::error::Result;
use crate::runtime::Runtime;
use crate::sparse::CsrData;
use crate::tensor::Tensor;
use super::types::{
IcDecomposition, IcOptions, IluDecomposition, IluFillLevel, IluOptions, IlukDecomposition,
IlukOptions, IlukSymbolic, SymbolicIlu0,
};
/// Sparse linear-algebra operations on [`CsrData`] matrices, implemented per runtime `R`.
///
/// The methods group into incomplete-factorization entry points (ILU(0), IC(0)
/// and ILU(k), going by their names) plus a sparse triangular solve.
///
/// Where a family exposes a `*_symbolic` / `*_numeric` pair, the symbolic call
/// returns an object that the numeric call takes back by reference. The
/// one-shot method (`ilu0`, `iluk`) needs no such object.
/// NOTE(review): presumably the split lets the symbolic result be reused across
/// matrices that share a sparsity pattern; confirm against the implementations.
pub trait SparseLinAlgAlgorithms<R: Runtime> {
    // ----- ILU(0) ---------------------------------------------------------

    /// Builds a [`SymbolicIlu0`] from `pattern` for later use with
    /// [`Self::ilu0_numeric`].
    ///
    /// NOTE(review): the parameter name suggests only the sparsity structure of
    /// `pattern` is consulted; verify in the implementations.
    fn ilu0_symbolic(&self, pattern: &CsrData<R>) -> Result<SymbolicIlu0>;

    /// Computes an [`IluDecomposition`] of `a` using a previously built
    /// `symbolic` object (see [`Self::ilu0_symbolic`]).
    fn ilu0_numeric(
        &self,
        a: &CsrData<R>,
        symbolic: &SymbolicIlu0,
        options: IluOptions,
    ) -> Result<IluDecomposition<R>>;

    /// Computes an [`IluDecomposition`] of `a` in a single call, without a
    /// caller-supplied symbolic object.
    fn ilu0(&self, a: &CsrData<R>, options: IluOptions) -> Result<IluDecomposition<R>>;

    // ----- IC(0) ----------------------------------------------------------

    /// Computes an [`IcDecomposition`] of `a` (incomplete Cholesky, going by the
    /// name; presumably `a` is expected to be symmetric — TODO confirm).
    fn ic0(&self, a: &CsrData<R>, options: IcOptions) -> Result<IcDecomposition<R>>;

    // ----- ILU(k) ---------------------------------------------------------

    /// Builds an [`IlukSymbolic`] for `a` at the requested fill `level`, for
    /// later use with [`Self::iluk_numeric`].
    fn iluk_symbolic(&self, a: &CsrData<R>, level: IluFillLevel) -> Result<IlukSymbolic>;

    /// Computes an [`IlukDecomposition`] of `a` using a previously built
    /// `symbolic` object (see [`Self::iluk_symbolic`]).
    ///
    /// Unlike the other methods, the options are taken by reference here.
    fn iluk_numeric(
        &self,
        a: &CsrData<R>,
        symbolic: &IlukSymbolic,
        opts: &IlukOptions,
    ) -> Result<IlukDecomposition<R>>;

    /// Computes an [`IlukDecomposition`] of `a` in a single call, configured by
    /// `opts`.
    fn iluk(&self, a: &CsrData<R>, opts: IlukOptions) -> Result<IlukDecomposition<R>>;

    // ----- Triangular solve -----------------------------------------------

    /// Solves a sparse triangular system with matrix `l_or_u` and right-hand
    /// side `b`, returning the solution tensor.
    ///
    /// NOTE(review): presumably `lower` selects lower- vs upper-triangular
    /// handling, and `unit_diagonal` means the diagonal is treated as all ones
    /// rather than read from `l_or_u`; confirm in the implementations.
    fn sparse_solve_triangular(
        &self,
        l_or_u: &CsrData<R>,
        b: &Tensor<R>,
        lower: bool,
        unit_diagonal: bool,
    ) -> Result<Tensor<R>>;
}
pub fn validate_square_sparse(shape: [usize; 2]) -> Result<usize> {
let [nrows, ncols] = shape;
if nrows != ncols {
return Err(crate::error::Error::ShapeMismatch {
expected: vec![nrows, nrows],
got: vec![nrows, ncols],
});
}
Ok(nrows)
}
/// Validates the operand shapes of a sparse triangular solve.
///
/// `matrix_shape` must be square (`[n, n]`). `b_shape` is the shape of the
/// right-hand side and must be either `[n]`, giving `nrhs == 1`, or
/// `[n, nrhs]`.
///
/// Returns `(n, nrhs)` on success.
///
/// # Errors
///
/// - `Error::ShapeMismatch` if the matrix is not square (reported by
///   [`validate_square_sparse`]).
/// - `Error::ShapeMismatch` if the leading dimension of `b` differs from `n`.
/// - `Error::Internal` if `b` is 0-D.
/// - `Error::Internal` if `b` has more than two dimensions.
pub fn validate_triangular_solve_dims(
    matrix_shape: [usize; 2],
    b_shape: &[usize],
) -> Result<(usize, usize)> {
    use crate::error::Error;

    let n = validate_square_sparse(matrix_shape)?;

    let (rows, nrhs) = match *b_shape {
        [] => {
            return Err(Error::Internal(
                "Right-hand side must be at least 1D".to_string(),
            ));
        }
        [rows] => (rows, 1),
        [rows, cols] => (rows, cols),
        // A rank-3+ right-hand side cannot be described by a single `nrhs`.
        // Reading only `b_shape[1]` would silently drop the trailing
        // dimensions, so reject it explicitly.
        _ => {
            return Err(Error::Internal(format!(
                "Right-hand side must be 1D or 2D, got {}D",
                b_shape.len()
            )));
        }
    };

    if rows != n {
        return Err(Error::ShapeMismatch {
            expected: vec![n],
            got: vec![rows],
        });
    }

    Ok((n, nrhs))
}