use crate::DType;
#[cfg(feature = "sparse")]
use numr::algorithm::sparse_linalg::{
ColamdOptions, LuFactors, LuOptions, LuSymbolic, LuWorkspace, colamd, hopcroft_karp,
sparse_lu_cpu_with_workspace_and_metrics, sparse_lu_solve_cpu,
};
#[cfg(feature = "sparse")]
use numr::error::Result;
#[cfg(feature = "sparse")]
use numr::ops::IndexingOps;
#[cfg(feature = "sparse")]
use numr::runtime::Runtime;
#[cfg(feature = "sparse")]
use numr::sparse::{CscData, SparseOps, SparseScaling, SparseStorage};
#[cfg(feature = "sparse")]
use numr::tensor::Tensor;
#[cfg(feature = "sparse")]
use super::direct_solver_config::DirectSolverConfig;
#[cfg(feature = "sparse")]
use super::sparse_utils::dense_to_csr_full;
#[cfg(feature = "sparse")]
use super::symbolic_analysis::compute_lu_symbolic;
/// Direct sparse LU solver that caches its ordering, symbolic analysis and
/// workspace between calls so that repeated solves only redo the numeric
/// factorization.
///
/// All cached state is `None` after construction and after `invalidate`; it is
/// populated by the first `solve` that follows.
#[cfg(feature = "sparse")]
pub struct DirectSparseSolver<R: Runtime<DType = DType>> {
/// Column ordering returned by `colamd` during the full analysis.
col_perm: Option<Vec<usize>>,
/// Row assignment derived from the `hopcroft_karp` matching. Only populated
/// when the column-permuted matrix does not have a full diagonal.
row_perm: Option<Vec<usize>>,
/// `row_perm` converted to an `i64` index tensor.
row_perm_tensor: Option<Tensor<R>>,
/// Inverse of `col_perm` as an `i64` index tensor; used to reorder the
/// solution after the triangular solves.
inv_col_perm_tensor: Option<Tensor<R>>,
/// Result of `compute_lu_symbolic` on the column-permuted pattern.
symbolic: Option<LuSymbolic>,
/// Factors produced by the most recent successful `solve`.
factors: Option<LuFactors<R>>,
/// Numeric factorization options copied from the configuration in `new`.
lu_options: LuOptions,
/// Column-permuted CSC matrix; its values are refreshed on every solve
/// after the first one.
cached_permuted_csc: Option<CscData<R>>,
/// Row coordinate in the dense matrix of each stored entry of
/// `cached_permuted_csc`, in storage order.
gather_row_indices: Option<Tensor<R>>,
/// Column coordinate in the dense matrix of each stored entry of
/// `cached_permuted_csc`, in storage order.
gather_col_indices: Option<Tensor<R>>,
/// Factorization workspace allocated together with `symbolic`.
workspace: Option<LuWorkspace>,
/// Row scaling factors returned by `equilibrate()` during the full
/// analysis (only when `equilibrate` is enabled).
row_scales: Option<Vec<f64>>,
/// Column scaling factors returned by `equilibrate()` during the full
/// analysis (only when `equilibrate` is enabled).
col_scales: Option<Vec<f64>>,
/// `row_scales` as a tensor of length `n`.
row_scales_tensor: Option<Tensor<R>>,
/// `col_scales` as a tensor of length `n`.
col_scales_tensor: Option<Tensor<R>>,
/// Whether the matrix, right-hand side and solution are scaled with the
/// cached equilibration factors.
equilibrate: bool,
/// Threshold that `pivot_growth_unreliable` compares `last_pivot_growth`
/// against.
pivot_growth_threshold: f64,
/// Number of numeric factorizations completed by `solve`.
pub refactor_count: usize,
/// Number of completed full analyses (ordering + symbolic).
pub reorder_count: usize,
/// Pivot growth reported by the most recent factorization.
pub last_pivot_growth: f64,
/// Small-pivot count reported by the most recent factorization.
pub last_small_pivots: usize,
}
#[cfg(feature = "sparse")]
impl<R: Runtime<DType = DType>> DirectSparseSolver<R> {
/// Creates a solver with an empty cache.
///
/// Numeric options (`pivot_tolerance`, `pivot_threshold`, `diagonal_shift`),
/// the equilibration flag and the pivot-growth threshold are copied from
/// `config`; zero checking is always enabled. No analysis is performed until
/// the first call to `solve`.
pub fn new(config: &DirectSolverConfig) -> Self {
    let lu_options = LuOptions {
        pivot_tolerance: config.pivot_tolerance,
        pivot_threshold: config.pivot_threshold,
        diagonal_shift: config.diagonal_shift,
        check_zeros: true,
    };

    Self {
        // Settings taken from the configuration.
        lu_options,
        equilibrate: config.equilibrate,
        pivot_growth_threshold: config.pivot_growth_threshold,

        // Ordering and symbolic state, filled in by the first solve.
        col_perm: None,
        row_perm: None,
        row_perm_tensor: None,
        inv_col_perm_tensor: None,
        symbolic: None,
        workspace: None,

        // Cached matrix, its gather coordinates and the latest factors.
        cached_permuted_csc: None,
        gather_row_indices: None,
        gather_col_indices: None,
        factors: None,

        // Equilibration factors.
        row_scales: None,
        col_scales: None,
        row_scales_tensor: None,
        col_scales_tensor: None,

        // Statistics.
        refactor_count: 0,
        reorder_count: 0,
        last_pivot_growth: 0.0,
        last_small_pivots: 0,
    }
}
/// Solves `m_dense * x = b` and returns `x`.
///
/// On the first call (or the first call after `invalidate`) the dense matrix
/// is converted to CSC and a full analysis is run. On later calls only the
/// values at the cached coordinates are gathered from `m_dense`, so the matrix
/// is expected to keep the sparsity pattern seen during the analysis.
///
/// Every call performs a fresh numeric factorization, updates
/// `last_pivot_growth` / `last_small_pivots`, and increments `refactor_count`.
///
/// # Errors
///
/// Propagates any error from the sparse conversion, the analysis, the value
/// gather, the scaling, the factorization or the triangular solve.
///
/// # Panics
///
/// Panics if the cached analysis state is incomplete, which would indicate a
/// broken internal invariant.
pub fn solve<C>(&mut self, client: &C, m_dense: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>
where
    C: SparseOps<R> + IndexingOps<R> + numr::ops::TensorOps<R> + numr::ops::ScalarOps<R>,
{
    if self.symbolic.is_none() {
        // No cached analysis: build the sparse matrix and analyse it.
        let csr = dense_to_csr_full(client, m_dense)?;
        let csc = csr.to_csc()?;
        let n = csc.shape()[0];
        self.full_analysis(&csc, n)?;
    } else {
        // Pattern is cached: only pull the current values out of the dense matrix.
        let rows = self
            .gather_row_indices
            .as_ref()
            .expect("gather indices set after full_analysis");
        let cols = self
            .gather_col_indices
            .as_ref()
            .expect("gather indices set after full_analysis");
        let values = client.gather_2d(m_dense, rows, cols)?;
        self.cached_permuted_csc
            .as_mut()
            .expect("cached CSC set after full_analysis")
            .update_values(values)?;
    }

    let cached = self
        .cached_permuted_csc
        .as_ref()
        .expect("cached CSC set after full_analysis");
    let n = cached.shape()[0];
    let symbolic = self
        .symbolic
        .as_ref()
        .expect("symbolic analysis set after full_analysis");
    let workspace = self
        .workspace
        .as_mut()
        .expect("workspace set after full_analysis");

    // Borrow the cached matrix directly unless scaling has to produce a new
    // one; this avoids copying the whole matrix on every solve.
    let scaled;
    let to_factor: &CscData<R> = match (self.equilibrate, &self.row_scales, &self.col_scales) {
        (true, Some(row_scales), Some(col_scales)) => {
            scaled = cached.scale_rows(row_scales)?.scale_cols(col_scales)?;
            &scaled
        }
        _ => cached,
    };

    let (factors, metrics) = sparse_lu_cpu_with_workspace_and_metrics(
        to_factor,
        symbolic,
        &self.lu_options,
        workspace,
    )?;
    self.last_pivot_growth = metrics.pivot_growth;
    self.last_small_pivots = metrics.small_pivots;
    self.refactor_count += 1;

    let solution = self.solve_with_factors(client, &factors, b, n)?;
    // Factors are only stored once the solve with them has succeeded.
    self.factors = Some(factors);
    Ok(solution)
}
/// Runs the one-time analysis for an `n x n` matrix and caches the results.
///
/// Steps, in order: COLAMD column ordering, column permutation of `csc`,
/// a `hopcroft_karp` matching when the permuted matrix lacks a full diagonal,
/// symbolic LU analysis and workspace allocation, optional equilibration,
/// and construction of the gather coordinates and permutation tensors.
///
/// All results are built in locals and written to `self` only after every
/// fallible step has succeeded, so an error leaves the solver exactly as it
/// was before the call (in particular `has_symbolic()` stays `false` and the
/// next `solve` retries the analysis instead of hitting half-initialised
/// state).
///
/// # Errors
///
/// Propagates errors from `colamd`, the column permutation, `hopcroft_karp`,
/// `equilibrate` and the gather-index construction.
fn full_analysis(&mut self, csc: &CscData<R>, n: usize) -> Result<()> {
    // Fill-reducing column ordering on the original pattern.
    let col_ptrs: Vec<i64> = csc.col_ptrs().to_vec();
    let row_indices: Vec<i64> = csc.row_indices().to_vec();
    let (col_perm, _stats) = colamd(n, n, &col_ptrs, &row_indices, &ColamdOptions::default())?;

    let perm_csc = Self::permute_csc_columns(csc, &col_perm, n)?;
    let perm_col_ptrs: Vec<i64> = perm_csc.col_ptrs().to_vec();
    let perm_row_indices: Vec<i64> = perm_csc.row_indices().to_vec();

    // NOTE(review): the matching below is only recorded. The rows of
    // `perm_csc`, the symbolic analysis and the gather coordinates all stay
    // in the original row order.
    let row_perm: Option<Vec<usize>> = if perm_csc.has_full_diagonal() {
        None
    } else {
        let matching = hopcroft_karp(n, n, &perm_col_ptrs, &perm_row_indices)?;
        let mut perm = vec![0usize; n];
        for (col, &row) in matching.col_to_row.iter().enumerate() {
            // Unmatched or out-of-range columns keep their own index.
            perm[col] = if row >= 0 && (row as usize) < n {
                row as usize
            } else {
                col
            };
        }
        Some(perm)
    };

    let symbolic = compute_lu_symbolic(n, &perm_col_ptrs, &perm_row_indices);
    let workspace = LuWorkspace::new(n, &symbolic);

    let device = perm_csc.col_ptrs().device();

    // Scaling factors are computed once here and reused by later solves.
    let (row_scales, col_scales, row_scales_tensor, col_scales_tensor) = if self.equilibrate {
        let (_scaled, row_scales, col_scales) = perm_csc.equilibrate::<f64>()?;
        let row_tensor = Tensor::<R>::from_slice(&row_scales, &[n], device);
        let col_tensor = Tensor::<R>::from_slice(&col_scales, &[n], device);
        (
            Some(row_scales.clone()),
            Some(col_scales.clone()),
            Some(row_tensor),
            Some(col_tensor),
        )
    } else {
        (None, None, None, None)
    };

    let (gather_rows, gather_cols) = Self::build_gather_indices(&perm_csc, n, &col_perm)?;

    let row_perm_tensor = row_perm.as_ref().map(|perm| {
        let perm_i64: Vec<i64> = perm.iter().map(|&i| i as i64).collect();
        Tensor::<R>::from_slice(&perm_i64, &[n], device)
    });

    // inv_col_perm[old] = new, so gathering with it maps a vector in permuted
    // column order back to the original order.
    let mut inv_col_perm = vec![0i64; n];
    for (new_col, &old_col) in col_perm[..n].iter().enumerate() {
        inv_col_perm[old_col] = new_col as i64;
    }
    let inv_col_perm_tensor = Tensor::<R>::from_slice(&inv_col_perm, &[n], device);

    // Every fallible step is done: commit the new state in one go.
    self.row_perm = row_perm;
    self.row_perm_tensor = row_perm_tensor;
    self.inv_col_perm_tensor = Some(inv_col_perm_tensor);
    self.symbolic = Some(symbolic);
    self.workspace = Some(workspace);
    self.row_scales = row_scales;
    self.col_scales = col_scales;
    self.row_scales_tensor = row_scales_tensor;
    self.col_scales_tensor = col_scales_tensor;
    self.gather_row_indices = Some(gather_rows);
    self.gather_col_indices = Some(gather_cols);
    self.cached_permuted_csc = Some(perm_csc);
    self.col_perm = Some(col_perm);
    self.reorder_count += 1;
    Ok(())
}
/// Builds the dense-matrix coordinates of every stored entry of `perm_csc`.
///
/// For each entry, in storage order, the returned pair of `i64` tensors holds
/// its row index as stored in `perm_csc` and the *original* column index
/// (`col_perm[permuted_column]`). Both tensors have one element per stored
/// row index.
fn build_gather_indices(
    perm_csc: &CscData<R>,
    n: usize,
    col_perm: &[usize],
) -> Result<(Tensor<R>, Tensor<R>)> {
    let col_ptrs: Vec<i64> = perm_csc.col_ptrs().to_vec();
    let row_indices: Vec<i64> = perm_csc.row_indices().to_vec();
    let nnz = row_indices.len();

    let mut src_rows: Vec<i64> = Vec::with_capacity(nnz);
    let mut src_cols: Vec<i64> = Vec::with_capacity(nnz);
    for (perm_col, &orig_col) in col_perm[..n].iter().enumerate() {
        let lo = col_ptrs[perm_col] as usize;
        let hi = col_ptrs[perm_col + 1] as usize;
        let column_rows = &row_indices[lo..hi];
        // Rows are taken as stored; the column is mapped back to the original.
        src_rows.extend_from_slice(column_rows);
        src_cols.resize(src_cols.len() + column_rows.len(), orig_col as i64);
    }

    let device = perm_csc.col_ptrs().device();
    Ok((
        Tensor::<R>::from_slice(&src_rows, &[nnz], device),
        Tensor::<R>::from_slice(&src_cols, &[nnz], device),
    ))
}
/// Returns a new `n x n` CSC matrix whose column `j` is column `col_perm[j]`
/// of `csc`.
///
/// Row indices and values are copied unchanged and keep their order within
/// each column; the result is created on the same device as `csc`'s column
/// pointers.
fn permute_csc_columns(csc: &CscData<R>, col_perm: &[usize], n: usize) -> Result<CscData<R>> {
    let src_ptrs: Vec<i64> = csc.col_ptrs().to_vec();
    let src_rows: Vec<i64> = csc.row_indices().to_vec();
    let src_vals: Vec<f64> = csc.values().to_vec();

    let mut dst_ptrs = vec![0i64; n + 1];
    let mut dst_rows: Vec<i64> = Vec::new();
    let mut dst_vals: Vec<f64> = Vec::new();
    for (dst_col, &src_col) in col_perm[..n].iter().enumerate() {
        let lo = src_ptrs[src_col] as usize;
        let hi = src_ptrs[src_col + 1] as usize;
        dst_rows.extend_from_slice(&src_rows[lo..hi]);
        dst_vals.extend_from_slice(&src_vals[lo..hi]);
        // Running entry count marks the end of the column just copied.
        dst_ptrs[dst_col + 1] = dst_rows.len() as i64;
    }

    let device = csc.col_ptrs().device();
    CscData::from_slices(&dst_ptrs, &dst_rows, &dst_vals, [n, n], device)
}
/// Solves the system for `b` using already computed `factors`.
///
/// The factored matrix is the cached matrix with its columns permuted and,
/// when equilibration is active, scaled as `Dr * (A * Pc) * Dc`. The solve
/// therefore:
///
/// 1. flattens an `[n, 1]` right-hand side to `[n]`,
/// 2. scales it by the row factors (`Dr * b`),
/// 3. runs the triangular solves,
/// 4. scales the result by the column factors (`Dc * y`),
/// 5. gathers with the inverse column permutation to restore the original
///    unknown ordering, and
/// 6. reshapes back to `[n, 1]` if that was the input shape.
///
/// The right-hand side is deliberately *not* reordered with the recorded row
/// matching. The cached matrix, its gather coordinates and the equilibration
/// factors are all kept in the original row order, so reordering `b` alone
/// would solve a different system.
/// NOTE(review): this assumes `sparse_lu_solve_cpu` applies whatever row
/// pivoting the factorization itself performed — confirm against numr.
///
/// # Errors
///
/// Propagates errors from the reshape, the element-wise scaling, the
/// triangular solve and the final gather.
///
/// # Panics
///
/// Panics if the inverse column permutation has not been set, i.e. if no
/// analysis has completed.
fn solve_with_factors<C>(
    &self,
    client: &C,
    factors: &LuFactors<R>,
    b: &Tensor<R>,
    n: usize,
) -> Result<Tensor<R>>
where
    C: IndexingOps<R> + numr::ops::TensorOps<R> + numr::ops::ScalarOps<R>,
{
    let b_shape = b.shape().to_vec();
    let is_column = b_shape.len() == 2 && b_shape[1] == 1;

    let rhs = if is_column {
        b.reshape(&[n])?
    } else {
        b.clone()
    };

    // Row scaling of the right-hand side matches the row scaling of the matrix.
    let rhs = match (self.equilibrate, &self.row_scales_tensor) {
        (true, Some(row_scales)) => client.mul(&rhs, row_scales)?,
        _ => rhs,
    };

    let y = sparse_lu_solve_cpu(factors, &rhs)?;

    // Undo the column scaling to get the solution of the unscaled system.
    let z = match (self.equilibrate, &self.col_scales_tensor) {
        (true, Some(col_scales)) => client.mul(&y, col_scales)?,
        _ => y,
    };

    // z is in permuted column order; x[i] = z[inv_col_perm[i]].
    let inv_col_perm = self
        .inv_col_perm_tensor
        .as_ref()
        .expect("inverse column permutation set after full_analysis");
    let x = client.index_select(&z, 0, inv_col_perm)?;

    if is_column {
        x.reshape(&[n, 1])
    } else {
        Ok(x)
    }
}
pub fn invalidate(&mut self) {
self.col_perm = None;
self.row_perm = None;
self.row_perm_tensor = None;
self.inv_col_perm_tensor = None;
self.symbolic = None;
self.factors = None;
self.cached_permuted_csc = None;
self.gather_row_indices = None;
self.gather_col_indices = None;
self.workspace = None;
self.row_scales = None;
self.col_scales = None;
self.row_scales_tensor = None;
self.col_scales_tensor = None;
}
/// Returns `true` when a symbolic analysis is currently cached, i.e. a full
/// analysis has been stored and `invalidate` has not been called since.
pub fn has_symbolic(&self) -> bool {
self.symbolic.is_some()
}
/// Returns `true` when the pivot growth of the most recent factorization
/// exceeds the configured threshold.
///
/// A NaN pivot growth is also reported as unreliable: a plain `>` comparison
/// is `false` for NaN, which would otherwise let a numerically broken
/// factorization pass as trustworthy.
pub fn pivot_growth_unreliable(&self) -> bool {
    self.last_pivot_growth.is_nan() || self.last_pivot_growth > self.pivot_growth_threshold
}
/// Returns `(pivot_growth, small_pivots)` from the most recent factorization,
/// or `None` if no factorization has been counted yet.
pub fn last_metrics(&self) -> Option<(f64, usize)> {
    match self.refactor_count {
        0 => None,
        _ => Some((self.last_pivot_growth, self.last_small_pivots)),
    }
}
}