use crate::error::Result;
use crate::runtime::Runtime;
use crate::sparse::CscData;
use super::types::{LuFactors, LuMetrics, LuOptions, LuSymbolic};
/// Sparse LU factorization operations provided for runtime backend `R`.
///
/// Input matrices are passed as [`CscData`] (compressed sparse column data).
/// NOTE(review): pivoting strategy, factor layout and storage of the results
/// are defined by implementors of this trait; confirm against them.
pub trait SparseLuOps<R: Runtime> {
/// Numerically factors `a` using a caller-supplied symbolic structure.
///
/// `symbolic` is presumably the result of a prior symbolic analysis of `a`'s
/// sparsity pattern (see [`validate_symbolic_pattern`] for one consistency
/// check that can be applied); `options` configures the factorization.
fn sparse_lu(
&self,
a: &CscData<R>,
symbolic: &LuSymbolic,
options: &LuOptions,
) -> Result<LuFactors<R>>;
/// Factors `a` without a caller-supplied [`LuSymbolic`].
///
/// NOTE(review): presumably performs the symbolic analysis internally — confirm.
fn sparse_lu_simple(&self, a: &CscData<R>, options: &LuOptions) -> Result<LuFactors<R>>;
/// Solves a linear system with previously computed `factors` and right-hand side `b`.
///
/// NOTE(review): accepted shapes of `b` are up to the implementor; the helper
/// `validate_lu_solve_dims` in this module checks a 1-D/2-D `b`.
fn sparse_lu_solve(
&self,
factors: &LuFactors<R>,
b: &crate::tensor::Tensor<R>,
) -> Result<crate::tensor::Tensor<R>>;
/// Variant of [`SparseLuOps::sparse_lu`] that additionally returns [`LuMetrics`].
fn sparse_lu_with_metrics(
&self,
a: &CscData<R>,
symbolic: &LuSymbolic,
options: &LuOptions,
) -> Result<(LuFactors<R>, LuMetrics)>;
}
/// Low-level kernels operating on a dense `f64` work buffer, for use by a
/// sparse LU factorization on runtime `R`.
///
/// NOTE(review): no implementation is visible here. The per-method
/// descriptions below are inferred from names and signatures only; confirm
/// them against the implementors before relying on them.
pub trait SparseLuKernels<R: Runtime> {
/// Presumably writes `values[k]` into `work` at position `row_indices[k]`.
fn scatter_column(&self, values: &[f64], row_indices: &[i64], work: &mut [f64]);
/// Presumably performs `work[row_indices[k]] += scale * values[k]` (or a signed variant).
fn sparse_axpy(&self, scale: f64, values: &[f64], row_indices: &[i64], work: &mut [f64]);
/// Presumably selects a pivot among `work` entries delimited by `start`/`end`,
/// returning its index and value. Whether `end` is exclusive is not shown here.
fn find_pivot(&self, work: &[f64], start: usize, end: usize) -> (usize, f64);
/// Presumably copies the `work` entries at `row_indices` into `output` and
/// resets those `work` entries.
fn gather_and_clear(&self, work: &mut [f64], row_indices: &[i64], output: &mut [f64]);
/// Presumably divides the `work` entries at `row_indices` by `pivot`.
fn divide_by_pivot(&self, work: &mut [f64], row_indices: &[i64], pivot: f64);
/// Presumably swaps rows `row_a` and `row_b` in `work` and records the swap in `perm`.
fn swap_rows(&self, work: &mut [f64], perm: &mut [usize], row_a: usize, row_b: usize);
}
/// Checks that a CSC column-pointer array is compatible with a symbolic LU structure.
///
/// Two checks are performed:
/// 1. `col_ptrs` has exactly `symbolic.n + 1` entries.
/// 2. The total nonzero count `col_ptrs[n]` is non-negative and does not exceed
///    `symbolic.l_nnz() + symbolic.u_nnz()`.
///
/// `_row_indices` is accepted for API stability but is not inspected yet.
///
/// # Errors
/// - [`crate::error::Error::ShapeMismatch`] if `col_ptrs.len() != n + 1`.
/// - [`crate::error::Error::Internal`] if the nonzero count is negative or
///   exceeds what the symbolic structure allows.
pub fn validate_symbolic_pattern(
    col_ptrs: &[i64],
    _row_indices: &[i64],
    symbolic: &LuSymbolic,
) -> Result<()> {
    let n = symbolic.n;
    if col_ptrs.len() != n + 1 {
        return Err(crate::error::Error::ShapeMismatch {
            expected: vec![n + 1],
            got: vec![col_ptrs.len()],
        });
    }

    // Safe to index: the length check above guarantees `n` is a valid index.
    let last = col_ptrs[n];
    // A negative count would wrap to a huge value under `as usize` and produce a
    // misleading "too many nonzeros" error, so reject it explicitly.
    if last < 0 {
        return Err(crate::error::Error::Internal(format!(
            "Column pointers end with a negative nonzero count ({})",
            last
        )));
    }
    let matrix_nnz = last as usize;

    let symbolic_nnz = symbolic.l_nnz() + symbolic.u_nnz();
    if matrix_nnz > symbolic_nnz {
        return Err(crate::error::Error::Internal(format!(
            "Matrix has more nonzeros ({}) than symbolic structure allows ({})",
            matrix_nnz, symbolic_nnz
        )));
    }
    Ok(())
}
pub fn validate_lu_solve_dims(
factors: &LuFactors<impl Runtime>,
b_shape: &[usize],
) -> Result<(usize, usize)> {
let n = factors.row_perm.len();
if b_shape.is_empty() {
return Err(crate::error::Error::Internal(
"Right-hand side must be at least 1D".to_string(),
));
}
if b_shape[0] != n {
return Err(crate::error::Error::ShapeMismatch {
expected: vec![n],
got: vec![b_shape[0]],
});
}
let nrhs = if b_shape.len() == 1 { 1 } else { b_shape[1] };
Ok((n, nrhs))
}