use crate::error::KError;
use crate::solver::MonitorCallback;
use crate::matrix::sparse::CsrMatrix;
use crate::parallel::{Comm, UniverseComm, contiguous_partition};
use crate::solver::api::Solver;
use crate::solver::legacy::LinearSolver;
use crate::utils::convergence::{ConvergedReason, SolveStats};
use faer::MatMut;
use faer::prelude::*;
use std::cmp::Ordering;
use std::collections::HashMap;
use mpi::collective::CommunicatorCollectives;
#[cfg(feature = "mpi")]
use mpi::raw::AsRaw;
#[cfg(feature = "mpi")]
use std::ffi::c_void;
#[cfg(feature = "mpi")]
use std::mem::MaybeUninit;
#[cfg(feature = "logging")]
use crate::utils::profiling::StageGuard;
/// Checks the structural invariants of a locally stored CSR matrix.
///
/// Verifies, in order:
/// 1. `row_ptr` has exactly `nrows + 1` entries and starts at 0;
/// 2. `row_ptr` never decreases;
/// 3. the final `row_ptr` entry matches both `col_idx.len()` and `values.len()`;
/// 4. every column index is `< ncols`.
///
/// # Errors
/// Returns `KError::InvalidInput` describing the first violated invariant.
// Kept for debugging/validation call sites that may be compiled out.
#[allow(dead_code)]
fn validate_local_csr(m: &CsrMatrix<f64>) -> Result<(), KError> {
    let row_ptr = m.row_ptr();
    let cols = m.col_idx();
    let vals = m.values();
    let nrows = m.nrows();
    let expected_len = nrows + 1;

    if row_ptr.len() != expected_len {
        return Err(KError::InvalidInput(format!(
            "CSR row_ptr length {} != nrows+1 = {}",
            row_ptr.len(),
            expected_len
        )));
    }
    if row_ptr.first().copied() != Some(0) {
        return Err(KError::InvalidInput("CSR row_ptr[0] must be 0".into()));
    }
    // Window k compares row_ptr[k] with row_ptr[k + 1].
    if let Some(k) = row_ptr.windows(2).position(|w| w[0] > w[1]) {
        return Err(KError::InvalidInput(format!(
            "CSR row_ptr not nondecreasing at row {k}"
        )));
    }

    // Length was checked above, so index `nrows` is the last entry.
    let nnz = row_ptr[nrows];
    if nnz != cols.len() || nnz != vals.len() {
        return Err(KError::InvalidInput(format!(
            "CSR nnz mismatch: row_ptr last={}, col_idx={}, values={}",
            nnz,
            cols.len(),
            vals.len()
        )));
    }

    let ncols = m.ncols();
    for (row, bounds) in row_ptr.windows(2).enumerate() {
        for pos in bounds[0]..bounds[1] {
            let col = cols[pos];
            if col >= ncols {
                return Err(KError::InvalidInput(format!(
                    "CSR col index {} out of range (ncols={}) at local row {} pos {}",
                    col, ncols, row, pos
                )));
            }
        }
    }
    Ok(())
}
/// Position of the calling rank inside a row-major `prows x pcols` process grid.
///
/// Rank `r` is placed at grid coordinates `(r / pcols, r % pcols)`.
#[derive(Debug, Clone)]
pub struct ProcessGrid {
    /// Number of process rows.
    pub prows: usize,
    /// Number of process columns.
    pub pcols: usize,
    /// Grid row of the calling rank.
    pub my_prow: usize,
    /// Grid column of the calling rank.
    pub my_pcol: usize,
    /// Rank of the calling process in the communicator.
    pub my_rank: usize,
    /// Total number of processes (`prows * pcols`).
    pub total_procs: usize,
}
impl ProcessGrid {
    /// Builds a grid for `comm` using the most nearly square factorization of
    /// its size.
    ///
    /// # Errors
    /// `KError::InvalidInput` if the communicator is empty, or if
    /// [`ProcessGrid::new_with_dims`] rejects the chosen shape.
    pub fn new_auto(comm: &UniverseComm) -> Result<Self, KError> {
        let total_procs = comm.size();
        if total_procs == 0 {
            return Err(KError::InvalidInput(
                "Process grid requires at least one process".into(),
            ));
        }
        let (prows, pcols) = Self::determine_optimal_grid(total_procs);
        Self::new_with_dims(comm, prows, pcols)
    }

    /// Builds a `prows x pcols` grid for `comm`, placing rank `r` at
    /// `(r / pcols, r % pcols)`.
    ///
    /// # Errors
    /// `KError::InvalidInput` if either dimension is zero or
    /// `prows * pcols` differs from the communicator size.
    pub fn new_with_dims(comm: &UniverseComm, prows: usize, pcols: usize) -> Result<Self, KError> {
        let total_procs = comm.size();
        let my_rank = comm.rank();
        // Zero dimensions would divide by zero below. `checked_mul` prevents a
        // product that wrapped in release builds from spuriously matching.
        if prows == 0 || pcols == 0 || prows.checked_mul(pcols) != Some(total_procs) {
            return Err(KError::InvalidInput(format!(
                "Process grid {prows}x{pcols} doesn't match MPI size {total_procs}"
            )));
        }
        let my_prow = my_rank / pcols;
        let my_pcol = my_rank % pcols;
        Ok(ProcessGrid {
            prows,
            pcols,
            my_prow,
            my_pcol,
            my_rank,
            total_procs,
        })
    }

    /// Returns `(prows, pcols)` where `prows` is the largest divisor of
    /// `size` not exceeding `sqrt(size)` and `pcols = size / prows`.
    ///
    /// `size == 0` yields `(1, 0)` instead of dividing by zero.
    fn determine_optimal_grid(size: usize) -> (usize, usize) {
        let mut prows = ((size as f64).sqrt() as usize).max(1);
        // Correct a float sqrt that rounded up: `prows > size / prows` is
        // equivalent to `prows * prows > size` but cannot overflow.
        while prows > 1 && prows > size / prows {
            prows -= 1;
        }
        // Terminates at 1 at the latest, since every size is divisible by 1.
        while size % prows != 0 {
            prows -= 1;
        }
        (prows, size / prows)
    }

    /// Maps a rank to its `(grid row, grid column)`.
    pub fn rank_to_coords(&self, rank: usize) -> (usize, usize) {
        (rank / self.pcols, rank % self.pcols)
    }

    /// Maps grid coordinates back to a rank (inverse of [`Self::rank_to_coords`]).
    pub fn coords_to_rank(&self, prow: usize, pcol: usize) -> usize {
        prow * self.pcols + pcol
    }

    /// Whether this rank's grid row owns `global_row` under a block-cyclic
    /// layout with `block_size` rows per block.
    pub fn owns_global_row(&self, global_row: usize, block_size: usize) -> bool {
        let block_row = global_row / block_size;
        block_row % self.prows == self.my_prow
    }

    /// Whether this rank's grid column owns `global_col` under a block-cyclic
    /// layout with `block_size` columns per block.
    pub fn owns_global_col(&self, global_col: usize, block_size: usize) -> bool {
        let block_col = global_col / block_size;
        block_col % self.pcols == self.my_pcol
    }
}
/// A 3D process grid: `pdepth` layers, each a row-major `prows x pcols` grid.
///
/// Rank layout: `rank = layer * (prows * pcols) + prow * pcols + pcol`.
#[cfg(feature = "superlu3d")]
#[derive(Debug, Clone)]
pub struct ProcessGrid3D {
    /// Process rows within one layer.
    pub prows: usize,
    /// Process columns within one layer.
    pub pcols: usize,
    /// Number of layers.
    pub pdepth: usize,
    /// Row of the calling rank inside its layer.
    pub my_prow: usize,
    /// Column of the calling rank inside its layer.
    pub my_pcol: usize,
    /// Layer index of the calling rank.
    pub my_pdepth: usize,
    /// Rank of the calling process.
    pub my_rank: usize,
    /// Total number of processes (`prows * pcols * pdepth`).
    pub total_procs: usize,
}

#[cfg(feature = "superlu3d")]
impl ProcessGrid3D {
    /// Splits the ranks of `g2d` into `depth` layers of `total / depth` ranks.
    ///
    /// Each layer keeps `g2d.pcols` columns when that divides the layer size.
    /// Otherwise a near-square factorization of the layer size is used.
    /// Per-layer dimensions always satisfy `prows * pcols == total / depth`,
    /// so `coords_to_rank(my_prow, my_pcol, my_pdepth) == my_rank`.
    ///
    /// # Errors
    /// `KError::InvalidInput` if `depth` is zero, the grid is empty, or
    /// `depth` does not divide the number of processes.
    pub fn from_2d_with_depth(g2d: &ProcessGrid, depth: usize) -> Result<Self, KError> {
        let total = g2d.total_procs;
        if depth == 0 || total == 0 || total % depth != 0 {
            return Err(KError::InvalidInput(format!("invalid 3D depth {depth}")));
        }
        let layer_size = total / depth;
        let (prows, pcols) = Self::layer_dims(g2d, layer_size);
        let my_layer = g2d.my_rank / layer_size;
        let my_inlayer = g2d.my_rank % layer_size;
        Ok(Self {
            prows,
            pcols,
            pdepth: depth,
            my_prow: my_inlayer / pcols,
            my_pcol: my_inlayer % pcols,
            my_pdepth: my_layer,
            my_rank: g2d.my_rank,
            total_procs: total,
        })
    }

    /// Picks `(prows, pcols)` for one layer with `prows * pcols == layer_size`
    /// (`layer_size >= 1`). With `depth == 1` this returns the 2D grid's shape.
    fn layer_dims(g2d: &ProcessGrid, layer_size: usize) -> (usize, usize) {
        if g2d.pcols > 0 && layer_size % g2d.pcols == 0 {
            return (layer_size / g2d.pcols, g2d.pcols);
        }
        let mut prows = ((layer_size as f64).sqrt() as usize).max(1);
        while prows > 1 && prows > layer_size / prows {
            prows -= 1;
        }
        while layer_size % prows != 0 {
            prows -= 1;
        }
        (prows, layer_size / prows)
    }

    /// Rank at `(prow, pcol)` in layer `pdepth`.
    #[inline]
    pub fn coords_to_rank(&self, prow: usize, pcol: usize, pdepth: usize) -> usize {
        let layer_size = self.prows * self.pcols;
        pdepth * layer_size + (prow * self.pcols + pcol)
    }
}
/// 2D block-cyclic layout of a `global_rows x global_cols` matrix over a
/// [`ProcessGrid`].
///
/// Global row block `b` is owned by grid row `b % prows` and stored locally as
/// local block `b / prows`. Columns work the same way with `pcols`.
#[derive(Debug, Clone)]
pub struct BlockCyclicDistribution {
    /// Process grid the matrix is distributed over.
    pub grid: ProcessGrid,
    /// Rows per distribution block.
    pub row_block_size: usize,
    /// Columns per distribution block.
    pub col_block_size: usize,
    /// Global number of rows.
    pub global_rows: usize,
    /// Global number of columns.
    pub global_cols: usize,
    /// Number of rows owned by the calling rank.
    pub local_rows: usize,
    /// Number of columns owned by the calling rank.
    pub local_cols: usize,
}

impl BlockCyclicDistribution {
    /// Creates the distribution and computes the calling rank's local extents.
    pub fn new(
        grid: ProcessGrid,
        global_rows: usize,
        global_cols: usize,
        row_block_size: usize,
        col_block_size: usize,
    ) -> Self {
        let local_rows =
            Self::calculate_local_dimension(global_rows, row_block_size, grid.prows, grid.my_prow);
        let local_cols =
            Self::calculate_local_dimension(global_cols, col_block_size, grid.pcols, grid.my_pcol);
        Self {
            grid,
            row_block_size,
            col_block_size,
            global_rows,
            global_cols,
            local_rows,
            local_cols,
        }
    }

    /// Number of row blocks of size `block_size` covering all global rows.
    pub fn n_row_blocks(&self, block_size: usize) -> usize {
        self.global_rows.div_ceil(block_size)
    }

    /// Number of column blocks of size `block_size` covering all global columns.
    pub fn n_col_blocks(&self, block_size: usize) -> usize {
        self.global_cols.div_ceil(block_size)
    }

    /// Rank owning block `(brow, bcol)`.
    pub fn owner_rank_of_block(&self, brow: usize, bcol: usize) -> usize {
        let prow = brow % self.grid.prows;
        let pcol = bcol % self.grid.pcols;
        self.grid.coords_to_rank(prow, pcol)
    }

    /// Rank owning diagonal block `(k, k)`.
    pub fn owner_rank_of_diag_block(&self, k: usize) -> usize {
        self.owner_rank_of_block(k, k)
    }

    /// Number of local rows held by `rank`.
    pub fn local_rows_for_rank(&self, rank: usize) -> usize {
        let (prow, _) = self.grid.rank_to_coords(rank);
        Self::calculate_local_dimension(
            self.global_rows,
            self.row_block_size,
            self.grid.prows,
            prow,
        )
    }

    /// Number of local columns held by `rank`.
    pub fn local_cols_for_rank(&self, rank: usize) -> usize {
        let (_, pcol) = self.grid.rank_to_coords(rank);
        Self::calculate_local_dimension(
            self.global_cols,
            self.col_block_size,
            self.grid.pcols,
            pcol,
        )
    }

    /// Number of indices of a `global_dim`-long axis owned by process
    /// `my_proc` out of `proc_dim`, under a block-cyclic layout with source
    /// process 0 (ScaLAPACK `NUMROC`).
    ///
    /// Process `p` owns blocks `p, p + proc_dim, p + 2*proc_dim, ...`. Only the
    /// final block can be partial, and it belongs to process
    /// `full_blocks % proc_dim`.
    fn calculate_local_dimension(
        global_dim: usize,
        block_size: usize,
        proc_dim: usize,
        my_proc: usize,
    ) -> usize {
        if global_dim == 0 {
            return 0;
        }
        let full_blocks = global_dim / block_size;
        let tail = global_dim % block_size;
        // Every process gets the same number of full "rounds" of blocks...
        let mut local = (full_blocks / proc_dim) * block_size;
        let extra = full_blocks % proc_dim;
        // ...the first `extra` processes get one more full block, and the next
        // process in cyclic order receives the partial tail block (if any).
        if my_proc < extra {
            local += block_size;
        } else if my_proc == extra {
            local += tail;
        }
        local
    }

    /// Local row index of `global_row`, or `None` if this rank does not own it.
    pub fn global_to_local_row(&self, global_row: usize) -> Option<usize> {
        self.local_row_from_global(global_row)
    }

    /// Local column index of `global_col`, or `None` if this rank does not own it.
    pub fn global_to_local_col(&self, global_col: usize) -> Option<usize> {
        self.local_col_from_global(global_col)
    }

    /// Global row index of this rank's `local_row`.
    pub fn local_to_global_row(&self, local_row: usize) -> usize {
        let local_block_id = local_row / self.row_block_size;
        let block_offset = local_row % self.row_block_size;
        let global_block_id = local_block_id * self.grid.prows + self.grid.my_prow;
        global_block_id * self.row_block_size + block_offset
    }

    /// Global column index of this rank's `local_col`.
    pub fn local_to_global_col(&self, local_col: usize) -> usize {
        let local_block_id = local_col / self.col_block_size;
        let block_offset = local_col % self.col_block_size;
        let global_block_id = local_block_id * self.grid.pcols + self.grid.my_pcol;
        global_block_id * self.col_block_size + block_offset
    }

    /// Global row-block index containing `global_row`.
    #[inline]
    pub fn row_block_of(&self, global_row: usize) -> usize {
        global_row / self.row_block_size
    }

    /// Global column-block index containing `global_col`.
    #[inline]
    pub fn col_block_of(&self, global_col: usize) -> usize {
        global_col / self.col_block_size
    }

    /// Grid coordinates of the owner of entry `(global_row, global_col)`.
    #[inline]
    pub fn owner_coords_of(&self, global_row: usize, global_col: usize) -> (usize, usize) {
        let br = self.row_block_of(global_row) % self.grid.prows;
        let bc = self.col_block_of(global_col) % self.grid.pcols;
        (br, bc)
    }

    /// Rank owning entry `(global_row, global_col)`.
    #[inline]
    pub fn owner_of(&self, global_row: usize, global_col: usize) -> usize {
        let (pr, pc) = self.owner_coords_of(global_row, global_col);
        self.grid.coords_to_rank(pr, pc)
    }

    /// Local row index of `global_row`, or `None` if not owned by this rank.
    #[inline]
    pub fn local_row_from_global(&self, global_row: usize) -> Option<usize> {
        if !self.owns_global_row(global_row, self.row_block_size) {
            return None;
        }
        let block_id_g = self.row_block_of(global_row);
        let block_id_l = block_id_g / self.grid.prows;
        let offset_in_block = global_row % self.row_block_size;
        Some(block_id_l * self.row_block_size + offset_in_block)
    }

    /// Local column index of `global_col`, or `None` if not owned by this rank.
    #[inline]
    pub fn local_col_from_global(&self, global_col: usize) -> Option<usize> {
        if !self.owns_global_col(global_col, self.col_block_size) {
            return None;
        }
        let block_id_g = self.col_block_of(global_col);
        let block_id_l = block_id_g / self.grid.pcols;
        let offset_in_block = global_col % self.col_block_size;
        Some(block_id_l * self.col_block_size + offset_in_block)
    }

    /// Delegates to [`ProcessGrid::owns_global_row`].
    #[inline]
    pub fn owns_global_row(&self, global_row: usize, block_size: usize) -> bool {
        self.grid.owns_global_row(global_row, block_size)
    }

    /// Delegates to [`ProcessGrid::owns_global_col`].
    #[inline]
    pub fn owns_global_col(&self, global_col: usize, block_size: usize) -> bool {
        self.grid.owns_global_col(global_col, block_size)
    }
}
#[cfg(test)]
mod dist_tests {
    use super::*;

    /// Builds the grid descriptor `my_rank` would see in a row-major
    /// `prows x pcols` layout, without needing a live communicator.
    fn make_grid(total: usize, prows: usize, pcols: usize, my_rank: usize) -> ProcessGrid {
        let (my_prow, my_pcol) = (my_rank / pcols, my_rank % pcols);
        ProcessGrid {
            prows,
            pcols,
            my_prow,
            my_pcol,
            my_rank,
            total_procs: total,
        }
    }

    /// For every rank, each global index it owns must map to a local index
    /// and back again; unowned indices must map to `None`.
    #[test]
    fn roundtrip_owned_indices() {
        // (global rows, global cols, (grid rows, grid cols), (row block, col block))
        let cases = [
            (5, 7, (2, 2), (2, 3)),
            (0, 0, (1, 1), (2, 2)),
            (8, 8, (2, 3), (3, 2)),
            (17, 9, (3, 1), (4, 4)),
        ];
        for &(nr, nc, (pr, pc), (br, bc)) in cases.iter() {
            let total = pr * pc;
            for rank in 0..total {
                let dist = BlockCyclicDistribution::new(
                    make_grid(total, pr, pc, rank),
                    nr,
                    nc,
                    br.max(1),
                    bc.max(1),
                );

                for i in 0..nr {
                    let owned = dist.owns_global_row(i, dist.row_block_size);
                    let local = dist.local_row_from_global(i);
                    match (owned, local) {
                        (false, None) => {}
                        (true, Some(loc)) => assert_eq!(
                            dist.local_to_global_row(loc),
                            i,
                            "row round-trip failed (i={i}, rank={rank:?})"
                        ),
                        other => panic!("inconsistent row ownership: {other:?}"),
                    }
                }

                for j in 0..nc {
                    let owned = dist.owns_global_col(j, dist.col_block_size);
                    let local = dist.local_col_from_global(j);
                    match (owned, local) {
                        (false, None) => {}
                        (true, Some(loc)) => assert_eq!(
                            dist.local_to_global_col(loc),
                            j,
                            "col round-trip failed (j={j}, rank={rank:?})"
                        ),
                        other => panic!("inconsistent col ownership: {other:?}"),
                    }
                }
            }
        }
    }

    /// `owner_of` must equal the row-major rank of `owner_coords_of`.
    #[test]
    fn owner_rank_agrees_with_coords() {
        let dist = BlockCyclicDistribution::new(make_grid(6, 2, 3, 0), 10, 11, 4, 3);
        let entries = (0..10).flat_map(|i| (0..11).map(move |j| (i, j)));
        for (i, j) in entries {
            let (pr, pc) = dist.owner_coords_of(i, j);
            assert_eq!(dist.owner_of(i, j), pr * dist.grid.pcols + pc);
        }
    }
}
/// Row-pivoting policy used by `Panel::factorize_lu`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PivotingStrategy {
    /// Partial pivoting: in each column, the entry of largest magnitude at or
    /// below the diagonal selects the pivot row.
    Dynamic,
    /// No row interchanges; the diagonal entry is always the pivot.
    Static,
    /// Use the diagonal entry while its magnitude is at least the threshold.
    /// The first time it falls below, switch to `Dynamic` for the remaining
    /// columns of the factorization.
    ThresholdWithFallback,
}
/// Dense column-major block of a sparse matrix, restricted to a column range
/// and a list of rows, that can be LU-factorized in place.
#[derive(Debug, Clone)]
pub struct Panel {
    /// Number of columns (`col_end - col_start`).
    pub width: usize,
    /// Number of rows (`row_indices.len()`).
    pub height: usize,
    /// Column-major `height x width` storage.
    pub data: Vec<f64>,
    /// CSR row indices the panel rows were gathered from, in panel order.
    pub row_indices: Vec<usize>,
    /// First global column covered by the panel.
    pub col_start: usize,
}

impl Panel {
    /// Gathers columns `col_start..col_end` of the CSR rows listed in
    /// `row_indices` into a dense column-major panel.
    ///
    /// Entries outside the column range are ignored. If a row stores the same
    /// column more than once, the last stored value wins.
    ///
    /// # Panics
    /// Panics if `col_start > col_end`, or if a row index is out of range for
    /// `matrix.row_ptr()`.
    pub fn from_sparse_columns(
        matrix: &CsrMatrix<f64>,
        col_start: usize,
        col_end: usize,
        row_indices: Vec<usize>,
    ) -> Self {
        // Without this check, the subtraction below wraps in release builds
        // and requests an enormous allocation.
        assert!(
            col_start <= col_end,
            "Panel::from_sparse_columns: col_start ({col_start}) > col_end ({col_end})"
        );
        let width = col_end - col_start;
        let height = row_indices.len();
        let mut data = vec![0.0; width * height];
        let row_ptrs = matrix.row_ptr();
        let col_indices = matrix.col_idx();
        let values = matrix.values();
        for (local_row, &global_row) in row_indices.iter().enumerate() {
            for idx in row_ptrs[global_row]..row_ptrs[global_row + 1] {
                let col = col_indices[idx];
                if (col_start..col_end).contains(&col) {
                    data[(col - col_start) * height + local_row] = values[idx];
                }
            }
        }
        Self {
            width,
            height,
            data,
            row_indices,
            col_start,
        }
    }

    /// Mutable faer view of the panel (`height x width`, column-major).
    pub fn as_faer_mut(&mut self) -> MatMut<'_, f64> {
        MatMut::from_column_major_slice_mut(&mut self.data, self.height, self.width)
    }

    /// Read-only faer view of the panel (`height x width`, column-major).
    pub fn as_faer(&self) -> faer::MatRef<'_, f64> {
        faer::MatRef::from_column_major_slice(&self.data, self.height, self.width)
    }

    /// Computes an in-place LU factorization `P * A = L * U` of the panel.
    ///
    /// On return, the strictly lower part of `data` holds the unit-lower `L`
    /// multipliers and the upper part holds `U`. The first
    /// `min(height, width)` columns are factored in blocks of up to 64
    /// columns. Each block is eliminated column by column (right-looking).
    /// The columns to its right are then updated with a triangular solve,
    /// `U12 = L11^{-1} * A12`, followed by a GEMM, `A22 -= L21 * U12`.
    ///
    /// Pivot rows are chosen according to `pivot_strategy`. Row interchanges
    /// are applied across whole rows, so the stored `L` stays consistent with
    /// the returned `row_permutation`. A pivot whose magnitude is below
    /// `threshold` is replaced by `threshold` carrying the pivot's sign (`+`
    /// for an exact zero), and counted in `tiny_pivots_replaced`. An
    /// exactly-zero pivot that is left in place also marks the result
    /// `is_singular`.
    ///
    /// # Errors
    /// Currently never returns `Err`; the `Result` is kept for API stability.
    pub fn factorize_lu(
        &mut self,
        threshold: f64,
        pivot_strategy: PivotingStrategy,
    ) -> Result<PanelFactorization, KError> {
        let m = self.height;
        let n = self.width;
        // Only columns with a diagonal entry can be pivoted on.
        let kmax = m.min(n);
        let kb = core::cmp::min(64, n.max(1));
        let mut row_perm: Vec<usize> = (0..m).collect();
        let mut num_row_swaps = 0usize;
        let mut tiny_pivots_replaced = 0usize;
        let mut is_singular = false;
        let mut pivot_strategy = pivot_strategy;
        let mut a = self.as_faer_mut();

        let mut j = 0;
        while j < kmax {
            let jb = kb.min(kmax - j);
            let block_end = j + jb;

            // Unblocked right-looking elimination of columns j..block_end.
            for gcol in j..block_end {
                let mut piv = gcol;
                let mut max_val = a[(gcol, gcol)].abs();
                let search = match pivot_strategy {
                    PivotingStrategy::Static => false,
                    PivotingStrategy::Dynamic => true,
                    PivotingStrategy::ThresholdWithFallback => {
                        if max_val < threshold {
                            // Stays dynamic for the rest of the panel.
                            pivot_strategy = PivotingStrategy::Dynamic;
                            true
                        } else {
                            false
                        }
                    }
                };
                if search {
                    for r in (gcol + 1)..m {
                        let val = a[(r, gcol)].abs();
                        if val > max_val {
                            max_val = val;
                            piv = r;
                        }
                    }
                }

                // Swap whole rows, including L columns of earlier blocks
                // (as LAPACK getrf does).
                if piv != gcol {
                    for c in 0..n {
                        let t = a[(gcol, c)];
                        a[(gcol, c)] = a[(piv, c)];
                        a[(piv, c)] = t;
                    }
                    row_perm.swap(gcol, piv);
                    num_row_swaps += 1;
                }

                // Tiny-pivot replacement happens after the swap, so it
                // modifies the pivot actually in use.
                let mut diag = a[(gcol, gcol)];
                if diag.abs() < threshold {
                    tiny_pivots_replaced += 1;
                    is_singular = true;
                    diag = if diag == 0.0 {
                        threshold
                    } else {
                        threshold.copysign(diag)
                    };
                    a[(gcol, gcol)] = diag;
                }

                if diag == 0.0 {
                    // Only reachable when threshold <= 0: nothing to eliminate with.
                    is_singular = true;
                    continue;
                }

                for r in (gcol + 1)..m {
                    a[(r, gcol)] /= diag;
                }
                // Rank-1 update of the not-yet-eliminated columns of this block.
                for c in (gcol + 1)..block_end {
                    let u = a[(gcol, c)];
                    if u != 0.0 {
                        for r in (gcol + 1)..m {
                            let l = a[(r, gcol)];
                            a[(r, c)] -= l * u;
                        }
                    }
                }
            }

            // Blocked update of the columns to the right of this block.
            if block_end < n {
                let sub = a.rb_mut().submatrix_mut(j, j, m - j, n - j);
                let (panel_cols, right) = sub.split_at_col_mut(jb);
                // block_end <= kmax <= m, so the panel has at least jb rows.
                let (l11, l21) = panel_cols.split_at_row_mut(jb);
                let (mut u12, mut a22) = right.split_at_row_mut(jb);
                faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                    l11.rb(),
                    u12.rb_mut(),
                    faer::Par::Seq,
                );
                if block_end < m {
                    faer::linalg::matmul::matmul(
                        a22.rb_mut(),
                        faer::Accum::Add,
                        l21.rb(),
                        u12.rb(),
                        -1.0,
                        faer::Par::Seq,
                    );
                }
            }
            j = block_end;
        }

        Ok(PanelFactorization {
            row_permutation: row_perm,
            pivot_strategy,
            diagonal_threshold: threshold,
            num_row_swaps,
            is_singular,
            tiny_pivots_replaced,
        })
    }
}
/// Result metadata of `Panel::factorize_lu`; the factors themselves stay in the panel.
#[derive(Debug, Clone)]
pub struct PanelFactorization {
    /// Row order after pivoting. Entry `i` is the panel-local index of the
    /// original row now stored at row `i` (starts as the identity).
    pub row_permutation: Vec<usize>,
    /// Strategy in effect when factorization finished. `ThresholdWithFallback`
    /// may have switched to `Dynamic` along the way.
    pub pivot_strategy: PivotingStrategy,
    /// The `threshold` argument passed to the factorization.
    pub diagonal_threshold: f64,
    /// Number of row interchanges performed.
    pub num_row_swaps: usize,
    /// Set when the factorization hit a pivot it treated as singular (e.g. one
    /// replaced because its magnitude was below the threshold).
    pub is_singular: bool,
    /// Number of pivots replaced because their magnitude was below the threshold.
    pub tiny_pivots_replaced: usize,
}
/// Distributed LU factors plus the metadata needed by the triangular solves.
#[derive(Debug, Clone)]
pub struct NumericFactorization {
    // NOTE(review): `n`/`nnz` presumably are the global order and nonzero
    // count of the factored matrix — confirm against the producer.
    pub n: usize,
    pub nnz: usize,
    /// Factored panels. `TriangularSolveData::new` uses this same list as both
    /// the L and the U factors, and the solver indexes it by 64-entry block id.
    pub panels: Vec<Panel>,
    /// Per-panel factorization metadata.
    pub panel_factors: Vec<PanelFactorization>,
    // NOTE(review): permutation/scaling vectors are not read by the solve code
    // in this chunk; their exact convention (forward vs inverse perm) is unconfirmed.
    pub global_row_perm: Vec<usize>,
    pub global_col_perm: Vec<usize>,
    pub row_scale: Vec<f64>,
    pub col_scale: Vec<f64>,
    /// Pivoting strategy requested for the factorization.
    pub pivot_strategy: PivotingStrategy,
    /// Pivot threshold requested for the factorization.
    pub pivot_threshold: f64,
    /// Whether any tiny pivot was replaced (presumably aggregated from `panel_factors`).
    pub replaced_tiny_pivots: bool,
    /// Summary statistics of the factorization.
    pub factor_stats: FactorizationStats,
    /// Per-block dependency lists used by the distributed forward solve.
    pub l_block_graph: Vec<Vec<usize>>,
    /// Per-block dependency lists used by the distributed backward solve.
    pub u_block_graph: Vec<Vec<usize>>,
}
/// Aggregate statistics reported for a numeric factorization.
// NOTE(review): no producer is visible in this chunk; field meanings below are
// inferred from names and should be confirmed.
#[derive(Debug, Clone)]
pub struct FactorizationStats {
    pub num_panels: usize,
    pub total_row_swaps: usize,
    pub tiny_pivots_replaced: usize,
    pub max_pivot_growth: f64,
    pub condition_estimate: Option<f64>,
    pub memory_usage: usize,
}
/// Requested communication topology for distributing solution blocks.
///
/// NOTE(review): in the solver code of this chunk the value is only logged;
/// the broadcasts are point-to-point fan-outs from the root regardless.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommPattern {
    PointToPoint,
    BinaryTree,
    Ring,
    Butterfly,
}
/// Bookkeeping for one nonblocking communication operation.
#[derive(Debug)]
pub struct CommRequest {
    /// Caller-chosen identifier used to find the request again in `wait`.
    pub request_id: usize,
    /// Sending rank.
    pub source_rank: usize,
    /// Receiving rank.
    pub dest_rank: usize,
    /// Message tag.
    pub tag: usize,
    /// Kind of operation.
    pub comm_type: CommType,
    /// Number of `f64` elements transferred.
    pub buffer_size: usize,
    /// True once the operation finished or failed.
    pub completed: bool,
    /// Error text, if the operation failed.
    pub error: Option<String>,
}
/// A request still tracked by `TriangularSolveData`, together with the owned
/// buffer that must stay alive until the operation finishes.
#[derive(Debug)]
struct PendingComm {
    meta: CommRequest,
    #[cfg(feature = "mpi")]
    handle: mpi::ffi::MPI_Request,
    buffer: Option<Vec<f64>>,
}
impl CommRequest {
    /// Creates a request in the not-completed, error-free state.
    pub fn new(
        request_id: usize,
        source_rank: usize,
        dest_rank: usize,
        tag: usize,
        comm_type: CommType,
        buffer_size: usize,
    ) -> Self {
        let completed = false;
        let error = None;
        CommRequest {
            completed,
            error,
            request_id,
            source_rank,
            dest_rank,
            tag,
            comm_type,
            buffer_size,
        }
    }
    /// Whether the request has finished (successfully or not).
    pub fn is_completed(&self) -> bool {
        let done = self.completed;
        done
    }
    /// Flags the request as finished.
    pub fn mark_completed(&mut self) {
        self.completed = true;
    }
    /// Records a failure; a failed request also counts as completed.
    pub fn set_error(&mut self, error: String) {
        self.completed = true;
        self.error = Some(error);
    }
}
/// Kind of communication a [`CommRequest`] tracks.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CommType {
    Send,
    Recv,
    Broadcast,
    Reduce,
    AllReduce,
}
#[derive(Debug)]
pub struct TriangularSolveData {
pub local_solution_blocks: Vec<Vec<f64>>,
pub comm_buffer: Vec<f64>,
pending_requests: Vec<PendingComm>,
pub block_owners: Vec<usize>,
pub block_sizes: Vec<usize>,
pub local_l_factors: Vec<Panel>,
pub local_u_factors: Vec<Panel>,
pub dependency_graph: Vec<Vec<usize>>,
}
impl TriangularSolveData {
pub fn new(
n: usize,
block_size: usize,
distribution: &BlockCyclicDistribution,
numeric_factor: &NumericFactorization,
deps: Vec<Vec<usize>>,
) -> Self {
let num_blocks = distribution.n_col_blocks(block_size);
let mut block_sizes = vec![block_size; num_blocks];
if num_blocks > 0 {
let rem = n % block_size;
if rem != 0 {
block_sizes[num_blocks - 1] = rem;
}
}
let mut block_owners = vec![0; num_blocks];
for k in 0..num_blocks {
block_owners[k] = distribution.owner_rank_of_diag_block(k);
}
let mut local_solution_blocks = Vec::new();
for k in 0..num_blocks {
if block_owners[k] == distribution.grid.my_rank {
local_solution_blocks.push(vec![0.0; block_sizes[k]]);
}
}
let dependency_graph = if deps.len() == num_blocks {
deps
} else {
vec![Vec::new(); num_blocks]
};
Self {
local_solution_blocks,
comm_buffer: vec![0.0; block_size * distribution.grid.total_procs],
pending_requests: Vec::new(),
block_owners,
block_sizes,
local_l_factors: numeric_factor.panels.clone(),
local_u_factors: numeric_factor.panels.clone(), dependency_graph,
}
}
pub fn isend(
&mut self,
data: &[f64],
dest_rank: usize,
tag: usize,
request_id: usize,
comm: &UniverseComm,
) -> Result<(), KError> {
let mut request = CommRequest::new(
request_id,
comm.rank(), dest_rank,
tag,
CommType::Send,
data.len(),
);
if comm.size() <= 1 {
request.mark_completed();
self.pending_requests.push(PendingComm {
meta: request,
#[cfg(feature = "mpi")]
handle: unsafe { mpi::ffi::RSMPI_REQUEST_NULL },
buffer: None,
});
return Ok(());
}
#[cfg(feature = "mpi")]
if let Some(world) = comm.as_mpi() {
let buffer = data.to_vec();
let mut handle = MaybeUninit::<mpi::ffi::MPI_Request>::uninit();
let rc = unsafe {
mpi::ffi::MPI_Isend(
buffer.as_ptr() as *const c_void,
buffer.len() as i32,
mpi::ffi::RSMPI_DOUBLE,
dest_rank as i32,
tag as i32,
world.as_raw(),
handle.as_mut_ptr(),
)
};
if rc != 0 {
request.set_error(format!("MPI_Isend failed with code {rc}"));
self.pending_requests.push(PendingComm {
meta: request,
handle: unsafe { mpi::ffi::RSMPI_REQUEST_NULL },
buffer: None,
});
return Err(KError::SolveError(format!(
"MPI_Isend failed with code {rc}"
)));
}
let handle = unsafe { handle.assume_init() };
self.pending_requests.push(PendingComm {
meta: request,
handle,
buffer: Some(buffer),
});
return Ok(());
}
Err(KError::InvalidInput(
"MPI communicator required for nonblocking send".into(),
))
}
pub fn irecv(
&mut self,
buffer_size: usize,
source_rank: usize,
tag: usize,
request_id: usize,
comm: &UniverseComm,
) -> Result<(), KError> {
let mut request = CommRequest::new(
request_id,
source_rank,
comm.rank(), tag,
CommType::Recv,
buffer_size,
);
if comm.size() <= 1 {
request.mark_completed();
self.pending_requests.push(PendingComm {
meta: request,
#[cfg(feature = "mpi")]
handle: unsafe { mpi::ffi::RSMPI_REQUEST_NULL },
buffer: None,
});
return Ok(());
}
#[cfg(feature = "mpi")]
if let Some(world) = comm.as_mpi() {
let mut buffer = vec![0.0f64; buffer_size];
let mut handle = MaybeUninit::<mpi::ffi::MPI_Request>::uninit();
let rc = unsafe {
mpi::ffi::MPI_Irecv(
buffer.as_mut_ptr() as *mut c_void,
buffer.len() as i32,
mpi::ffi::RSMPI_DOUBLE,
source_rank as i32,
tag as i32,
world.as_raw(),
handle.as_mut_ptr(),
)
};
if rc != 0 {
request.set_error(format!("MPI_Irecv failed with code {rc}"));
self.pending_requests.push(PendingComm {
meta: request,
handle: unsafe { mpi::ffi::RSMPI_REQUEST_NULL },
buffer: None,
});
return Err(KError::SolveError(format!(
"MPI_Irecv failed with code {rc}"
)));
}
let handle = unsafe { handle.assume_init() };
self.pending_requests.push(PendingComm {
meta: request,
handle,
buffer: Some(buffer),
});
return Ok(());
}
Err(KError::InvalidInput(
"MPI communicator required for nonblocking receive".into(),
))
}
pub fn wait(
&mut self,
request_id: usize,
target: Option<&mut [f64]>,
) -> Result<(), KError> {
if let Some(index) = self
.pending_requests
.iter()
.position(|req| req.meta.request_id == request_id)
{
let mut req = self.pending_requests.remove(index);
#[cfg(feature = "mpi")]
{
if req.meta.completed {
return Ok(());
}
let rc = unsafe { mpi::ffi::MPI_Wait(&mut req.handle, mpi::ffi::RSMPI_STATUS_IGNORE) };
if rc != 0 {
return Err(KError::SolveError(format!(
"MPI_Wait failed with code {rc}"
)));
}
}
if req.meta.comm_type == CommType::Recv {
if let Some(buffer) = req.buffer.take() {
self.comm_buffer.resize(buffer.len(), 0.0);
self.comm_buffer.copy_from_slice(&buffer);
if let Some(target) = target {
target.copy_from_slice(&buffer);
}
}
}
req.meta.mark_completed();
}
#[cfg(feature = "logging")]
log::debug!("Completed communication request {request_id}");
Ok(())
}
pub fn test(&self, _request_id: usize) -> bool {
true
}
pub fn wait_all(&mut self) -> Result<(), KError> {
#[cfg(feature = "logging")]
log::debug!(
"Waiting for {} pending requests",
self.pending_requests.len()
);
let mut pending = Vec::new();
std::mem::swap(&mut pending, &mut self.pending_requests);
for mut req in pending {
#[cfg(feature = "mpi")]
{
if !req.meta.completed {
let rc =
unsafe { mpi::ffi::MPI_Wait(&mut req.handle, mpi::ffi::RSMPI_STATUS_IGNORE) };
if rc != 0 {
return Err(KError::SolveError(format!(
"MPI_Wait failed with code {rc}"
)));
}
}
}
if req.meta.comm_type == CommType::Recv {
if let Some(buffer) = req.buffer.take() {
self.comm_buffer.resize(buffer.len(), 0.0);
self.comm_buffer.copy_from_slice(&buffer);
}
}
}
Ok(())
}
}
/// Stateless driver for distributed block triangular solves.
///
/// All entry points are associated functions. Per-call state lives in a
/// `TriangularSolveData` built inside each solve.
pub struct DistributedTriangularSolver;
impl DistributedTriangularSolver {
    /// Forward substitution over fixed 64-entry blocks using the panels in
    /// `numeric_factor` as unit-lower L blocks.
    ///
    /// `x` is first overwritten with `b`. Blocks are then visited in increasing
    /// order. The rank owning a block's diagonal (per `distribution`) solves it
    /// in place. With `overlap_comm`, the owner posts nonblocking sends and the
    /// other ranks receive immediately. Otherwise each block goes through
    /// `synchronous_broadcast`. Remote dependencies come from
    /// `numeric_factor.l_block_graph`. Returns `Ok(())` immediately for empty `b`.
    ///
    /// # Errors
    /// Propagates errors from the local solve and communication helpers.
    ///
    /// # Panics
    /// `x.copy_from_slice(b)` panics if `x.len() != b.len()`.
    ///
    /// NOTE(review): remote dependency updates are applied *after* the block's
    /// diagonal solve here, while `backward_solve` applies them before.
    /// Forward substitution normally updates first, then solves — confirm.
    pub fn forward_solve(
        b: &[f64],
        x: &mut [f64],
        numeric_factor: &NumericFactorization,
        distribution: &BlockCyclicDistribution,
        comm: &UniverseComm,
        comm_pattern: CommPattern,
        overlap_comm: bool,
        #[cfg(feature = "superlu3d")] grid3d: Option<&ProcessGrid3D>,
    ) -> Result<(), KError> {
        #[cfg(feature = "logging")]
        let _guard = StageGuard::new("DistributedForwardSolve");
        let n = b.len();
        if n == 0 {
            return Ok(());
        }
        // Fixed solve block size. The block loop below indexes the tables of
        // `TriangularSolveData`, so they must cover `n.div_ceil(64)` blocks.
        let block_size = 64; let num_blocks = n.div_ceil(block_size);
        #[cfg(feature = "logging")]
        log::debug!("Starting forward solve: n={n}, blocks={num_blocks}, pattern={comm_pattern:?}");
        let mut solve_data = TriangularSolveData::new(
            n,
            block_size,
            distribution,
            numeric_factor,
            numeric_factor.l_block_graph.clone(),
        );
        x.copy_from_slice(b);
        for block_id in 0..num_blocks {
            let block_start = block_id * block_size;
            let current_block_size = solve_data.block_sizes[block_id];
            let block_end = block_start + current_block_size;
            if solve_data.block_owners[block_id] == distribution.grid.my_rank {
                // Owner: solve the diagonal block locally, then (overlap mode)
                // start sending the result to every other rank.
                Self::solve_local_l_block(
                    &mut x[block_start..block_end],
                    &solve_data.local_l_factors,
                    block_id,
                )?;
                if overlap_comm {
                    Self::start_nonblocking_broadcast(
                        &mut solve_data,
                        &x[block_start..block_end],
                        block_id,
                        distribution,
                        comm_pattern,
                        comm,
                        #[cfg(feature = "superlu3d")]
                        grid3d,
                    )?;
                }
            } else if overlap_comm {
                // Non-owner: receive the solved block (tag = block_id) and
                // block until it arrives.
                let owner_rank = solve_data.block_owners[block_id];
                solve_data.irecv(current_block_size, owner_rank, block_id, block_id, comm)?;
                solve_data.wait(block_id, Some(&mut x[block_start..block_end]))?;
            }
            // NOTE(review): `wait(dep_block, ..)` is a no-op once that request
            // was already removed by an earlier wait, and `comm_buffer` then
            // holds whatever payload was received last — verify the update
            // really uses dep_block's data.
            let dependency_blocks = solve_data.dependency_graph[block_id].clone();
            for dep_block in dependency_blocks {
                if solve_data.block_owners[dep_block] != distribution.grid.my_rank {
                    if overlap_comm {
                        solve_data.wait(dep_block, None)?;
                    }
                    Self::apply_block_update(
                        &mut x[block_start..block_end],
                        &solve_data.comm_buffer,
                        dep_block,
                        block_id,
                        &solve_data.local_l_factors,
                    )?;
                }
            }
            // Blocking mode: the owner's block overwrites everyone else's copy.
            if !overlap_comm {
                Self::synchronous_broadcast(
                    &mut x[block_start..block_end],
                    solve_data.block_owners[block_id],
                    block_id,
                    comm,
                    comm_pattern,
                    #[cfg(feature = "superlu3d")]
                    grid3d,
                )?;
            }
        }
        // Drain outstanding sends so their buffers are released safely.
        if overlap_comm {
            solve_data.wait_all()?;
        }
        #[cfg(feature = "logging")]
        log::debug!("Forward solve completed successfully");
        Ok(())
    }
    /// Backward substitution over fixed 64-entry blocks using the panels in
    /// `numeric_factor` as upper-triangular U blocks.
    ///
    /// `x` is first overwritten with `y`. Blocks are visited from last to
    /// first. Remote dependency updates (from `numeric_factor.u_block_graph`)
    /// are applied, then the owning rank solves the block and shares it
    /// exactly as in `forward_solve`. Returns `Ok(())` immediately for empty `y`.
    ///
    /// # Errors
    /// Propagates errors from the local solve and communication helpers.
    ///
    /// # Panics
    /// `x.copy_from_slice(y)` panics if `x.len() != y.len()`.
    pub fn backward_solve(
        y: &[f64],
        x: &mut [f64],
        numeric_factor: &NumericFactorization,
        distribution: &BlockCyclicDistribution,
        comm: &UniverseComm,
        comm_pattern: CommPattern,
        overlap_comm: bool,
        #[cfg(feature = "superlu3d")] grid3d: Option<&ProcessGrid3D>,
    ) -> Result<(), KError> {
        #[cfg(feature = "logging")]
        let _guard = StageGuard::new("DistributedBackwardSolve");
        let n = y.len();
        if n == 0 {
            return Ok(());
        }
        // Must match the block tables built by `TriangularSolveData::new`.
        let block_size = 64; let num_blocks = n.div_ceil(block_size);
        #[cfg(feature = "logging")]
        log::debug!(
            "Starting backward solve: n={n}, blocks={num_blocks}, pattern={comm_pattern:?}"
        );
        let mut solve_data = TriangularSolveData::new(
            n,
            block_size,
            distribution,
            numeric_factor,
            numeric_factor.u_block_graph.clone(),
        );
        x.copy_from_slice(y);
        for block_id in (0..num_blocks).rev() {
            let block_start = block_id * block_size;
            let current_block_size = solve_data.block_sizes[block_id];
            let block_end = block_start + current_block_size;
            // Apply contributions from remotely owned dependency blocks first.
            let dependency_blocks = solve_data.dependency_graph[block_id].clone();
            for dep_block in dependency_blocks {
                if solve_data.block_owners[dep_block] != distribution.grid.my_rank {
                    if overlap_comm {
                        solve_data.wait(dep_block, None)?;
                    }
                    Self::apply_block_update_backward(
                        &mut x[block_start..block_end],
                        &solve_data.comm_buffer,
                        dep_block,
                        block_id,
                        &solve_data.local_u_factors,
                    )?;
                }
            }
            if solve_data.block_owners[block_id] == distribution.grid.my_rank {
                Self::solve_local_u_block(
                    &mut x[block_start..block_end],
                    &solve_data.local_u_factors,
                    block_id,
                )?;
                if overlap_comm {
                    Self::start_nonblocking_broadcast(
                        &mut solve_data,
                        &x[block_start..block_end],
                        block_id,
                        distribution,
                        comm_pattern,
                        comm,
                        #[cfg(feature = "superlu3d")]
                        grid3d,
                    )?;
                }
            } else if overlap_comm {
                let owner_rank = solve_data.block_owners[block_id];
                solve_data.irecv(current_block_size, owner_rank, block_id, block_id, comm)?;
                solve_data.wait(block_id, Some(&mut x[block_start..block_end]))?;
            }
            if !overlap_comm {
                Self::synchronous_broadcast(
                    &mut x[block_start..block_end],
                    solve_data.block_owners[block_id],
                    block_id,
                    comm,
                    comm_pattern,
                    #[cfg(feature = "superlu3d")]
                    grid3d,
                )?;
            }
        }
        if overlap_comm {
            solve_data.wait_all()?;
        }
        #[cfg(feature = "logging")]
        log::debug!("Backward solve completed successfully");
        Ok(())
    }
    /// Solves `L11 * x = x_block` in place, where `L11` is the unit-lower
    /// leading `k x k` part of panel `block_id` and
    /// `k = min(x_block.len(), panel width, panel height)`.
    ///
    /// Entries of `x_block` past `k` are left untouched. If there is no panel
    /// `block_id`, the block is silently left unchanged.
    fn solve_local_l_block(
        x_block: &mut [f64],
        l_factors: &[Panel],
        block_id: usize,
    ) -> Result<(), KError> {
        if let Some(panel) = l_factors.get(block_id) {
            let m = panel.height;
            let n = panel.width;
            let x_len = x_block.len();
            let mut x = faer::MatMut::from_column_major_slice_mut(x_block, x_len, 1);
            let l = panel.as_faer();
            let k = x_len.min(n).min(m);
            let l_square = l.submatrix(0, 0, k, k);
            let mut x_sub = x.rb_mut().submatrix_mut(0, 0, k, 1);
            faer::linalg::triangular_solve::solve_unit_lower_triangular_in_place(
                l_square,
                x_sub.rb_mut(),
                faer::Par::Seq,
            );
        }
        Ok(())
    }
    /// Solves `U11 * x = x_block` in place with the (non-unit) upper leading
    /// `k x k` part of panel `block_id`. `k` is defined as in `solve_local_l_block`.
    ///
    /// If there is no panel `block_id`, the block is silently left unchanged.
    fn solve_local_u_block(
        x_block: &mut [f64],
        u_factors: &[Panel],
        block_id: usize,
    ) -> Result<(), KError> {
        if let Some(panel) = u_factors.get(block_id) {
            let m = panel.height;
            let n = panel.width;
            let x_len = x_block.len();
            let mut x = faer::MatMut::from_column_major_slice_mut(x_block, x_len, 1);
            let u = panel.as_faer();
            let k = x_len.min(n).min(m);
            let u_square = u.submatrix(0, 0, k, k);
            let mut x_sub = x.rb_mut().submatrix_mut(0, 0, k, 1);
            faer::linalg::triangular_solve::solve_upper_triangular_in_place(
                u_square,
                x_sub.rb_mut(),
                faer::Par::Seq,
            );
        }
        Ok(())
    }
    /// Starts sending a solved block from its owner to every other rank.
    ///
    /// Each send uses tag `block_id` and request id `block_id`.
    /// `comm_pattern` is only logged; the fan-out is always point-to-point.
    /// With a 3D grid, additional sends go to this rank's binary-tree children
    /// in the depth dimension (layers `2d+1`, `2d+2`).
    ///
    /// NOTE(review): the 3D sends use tag `(block_id << 8) + layer`, which no
    /// receive in this chunk posts (receivers use tag `block_id`). Large block
    /// ids may also exceed the MPI tag upper bound — verify.
    fn start_nonblocking_broadcast(
        solve_data: &mut TriangularSolveData,
        data: &[f64],
        block_id: usize,
        distribution: &BlockCyclicDistribution,
        comm_pattern: CommPattern,
        comm: &UniverseComm,
        #[cfg(feature = "superlu3d")] grid3d: Option<&ProcessGrid3D>,
    ) -> Result<(), KError> {
        // Falls back to this rank if the block id is outside the owner table.
        let root_rank = solve_data
            .block_owners
            .get(block_id)
            .copied()
            .unwrap_or(distribution.grid.my_rank);
        #[cfg(feature = "logging")]
        log::debug!(
            "Starting nonblocking broadcast from rank {root_rank} for block {block_id} using {comm_pattern:?}"
        );
        if distribution.grid.my_rank == root_rank {
            for rank in 0..distribution.grid.total_procs {
                if rank != root_rank {
                    solve_data.isend(data, rank, block_id, block_id, comm)?;
                }
            }
        }
        #[cfg(feature = "superlu3d")]
        if let Some(g3) = grid3d {
            let layers = g3.pdepth;
            let left = 2 * g3.my_pdepth + 1;
            let right = 2 * g3.my_pdepth + 2;
            if left < layers {
                let r = g3.coords_to_rank(g3.my_prow, g3.my_pcol, left);
                solve_data.isend(data, r, (block_id << 8) + left, block_id, comm)?;
            }
            if right < layers {
                let r = g3.coords_to_rank(g3.my_prow, g3.my_pcol, right);
                solve_data.isend(data, r, (block_id << 8) + right, block_id, comm)?;
            }
        }
        Ok(())
    }
    /// Blocking broadcast of `data` from `root_rank`: the root posts one
    /// nonblocking send per other rank, every other rank posts one receive
    /// into `data`, and all requests are then waited on.
    ///
    /// Nothing happens on a single-process communicator. `comm_pattern` and
    /// `block_id` are only logged.
    ///
    /// NOTE(review): the 3D branch computes child ranks but discards them, so
    /// it performs no communication.
    fn synchronous_broadcast(
        data: &mut [f64],
        root_rank: usize,
        block_id: usize,
        comm: &UniverseComm,
        comm_pattern: CommPattern,
        #[cfg(feature = "superlu3d")] grid3d: Option<&ProcessGrid3D>,
    ) -> Result<(), KError> {
        #[cfg(feature = "logging")]
        log::debug!(
            "Synchronous broadcast from rank {root_rank} for block {block_id} using {comm_pattern:?}"
        );
        let size = comm.size();
        if size > 1 {
            let mut reqs = Vec::new();
            if comm.rank() == root_rank {
                for rank in 0..size {
                    if rank != root_rank {
                        reqs.push(comm.isend_to(data, rank as i32));
                    }
                }
            } else {
                reqs.push(comm.irecv_from(data, root_rank as i32));
            }
            comm.wait_all(&mut reqs);
        }
        #[cfg(feature = "superlu3d")]
        if let Some(g3) = grid3d {
            let layers = g3.pdepth;
            let left = 2 * g3.my_pdepth + 1;
            let right = 2 * g3.my_pdepth + 2;
            if left < layers {
                let _ = g3.coords_to_rank(g3.my_prow, g3.my_pcol, left);
            }
            if right < layers {
                let _ = g3.coords_to_rank(g3.my_prow, g3.my_pcol, right);
            }
        }
        Ok(())
    }
    /// Computes `x_block -= L[0..m, source_block] * update_data[source_block]`.
    ///
    /// Here `L` is panel `target_block` and `m = x_block.len()`. It is a no-op
    /// if the panel is missing, or if `source_block` is out of range for the
    /// panel's columns or for `update_data`.
    ///
    /// NOTE(review): this uses the *block* index `source_block` as both a
    /// column index and an element index, i.e. it applies a single scalar
    /// contribution. It also requires the panel height to be at least
    /// `x_block.len()`, or faer's `submatrix` panics. Confirm the intended
    /// block-update semantics.
    fn apply_block_update(
        x_block: &mut [f64],
        update_data: &[f64],
        source_block: usize,
        target_block: usize,
        l_factors: &[Panel],
    ) -> Result<(), KError> {
        if let Some(l_panel) = l_factors.get(target_block) {
            let m = x_block.len();
            let l = l_panel.as_faer();
            if source_block < l.ncols() && source_block < update_data.len() {
                let col = l.submatrix(0, source_block, m, 1);
                let scalar = faer::MatRef::from_column_major_slice(
                    &update_data[source_block..source_block + 1],
                    1,
                    1,
                );
                let mut x = MatMut::from_column_major_slice_mut(x_block, m, 1);
                faer::linalg::matmul::matmul(
                    x.rb_mut(),
                    faer::Accum::Add,
                    col,
                    scalar,
                    -1.0,
                    faer::Par::Seq,
                );
            }
        }
        Ok(())
    }
    /// Backward-solve counterpart of `apply_block_update`, using the U panels.
    /// It has the same guards and carries the same review note.
    fn apply_block_update_backward(
        x_block: &mut [f64],
        update_data: &[f64],
        source_block: usize,
        target_block: usize,
        u_factors: &[Panel],
    ) -> Result<(), KError> {
        if let Some(u_panel) = u_factors.get(target_block) {
            let m = x_block.len();
            let u = u_panel.as_faer();
            if source_block < u.ncols() && source_block < update_data.len() {
                let col = u.submatrix(0, source_block, m, 1);
                let scalar = faer::MatRef::from_column_major_slice(
                    &update_data[source_block..source_block + 1],
                    1,
                    1,
                );
                let mut x = MatMut::from_column_major_slice_mut(x_block, m, 1);
                faer::linalg::matmul::matmul(
                    x.rb_mut(),
                    faer::Accum::Add,
                    col,
                    scalar,
                    -1.0,
                    faer::Par::Seq,
                );
            }
        }
        Ok(())
    }
}
/// Configuration for the distributed direct solver.
///
/// `SuperLuDistOptions::validate` enforces the value constraints noted on the
/// fields below; `Default` supplies the baseline configuration.
#[derive(Debug, Clone)]
pub struct SuperLuDistOptions {
// Explicit `(rows, cols)` process grid. When set, both dimensions must be > 0
// and `validate` requires `rows * cols` to equal the communicator size.
// NOTE(review): how the grid is chosen when `None` is not visible here.
pub process_grid: Option<(usize, usize)>,
// Column ordering selection (see `ColumnPermutation`).
pub column_permutation: ColumnPermutation,
// Pivot threshold; `validate` requires it to lie in [0, 1] (NaN is rejected).
pub diagonal_pivot_threshold: f64,
// NOTE(review): presumably enables replacement of tiny pivots; the replacement rule is not in this chunk.
pub replace_tiny_pivots: bool,
// Refinement mode applied after the solve (see `IterativeRefinement`).
pub iterative_refinement: IterativeRefinement,
// Verbosity level consulted by `SuperLuDistOptions::enabled`.
pub print_level: u8,
// NOTE(review): semantics not visible in this chunk.
pub static_pivoting: bool,
// Row permutation selection (see `RowPermutation`).
pub row_permutation: RowPermutation,
// Optional panel width override; `Some(0)` is rejected by `validate`.
pub panel_size: Option<usize>,
// Enables the 3D factorization path.
pub enable_3d_factorization: bool,
// Depth of the 3D process grid; `Some(0)` is rejected when 3D is enabled.
pub process_grid_3d_depth: Option<usize>,
// Must be >= 0.1 and not NaN per `validate`; its effect is not visible here.
pub memory_tradeoff_factor: f64,
// NOTE(review): not validated in this chunk; default is 1.
pub max_concurrent_panels: usize,
// NOTE(review): semantics not visible in this chunk; default is false.
pub async_panel_updates: bool,
}
/// Column-ordering strategy requested through
/// `SuperLuDistOptions::column_permutation`.
///
/// NOTE(review): the dispatch from variant to algorithm is not in this chunk;
/// `OrderingAlgorithms` provides natural, MMD, METIS (AMD fallback) and
/// ParMETIS-path orderings that presumably back these variants — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ColumnPermutation {
// Identity ordering.
Natural,
// Multiple-minimum-degree ordering (default).
MmdAta,
Metis,
ParMetis,
// Presumably a caller-supplied permutation — confirm.
User,
}
/// Row-permutation strategy requested through
/// `SuperLuDistOptions::row_permutation` (default: `LargeDiag`).
///
/// NOTE(review): the algorithms behind these variants are not in this chunk.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RowPermutation {
// No row permutation.
NoRowPerm,
// Presumably permutes large entries onto the diagonal — confirm.
LargeDiag,
// Presumably a caller-supplied permutation — confirm.
User,
}
/// Iterative-refinement mode requested through
/// `SuperLuDistOptions::iterative_refinement` (default: `Double`).
///
/// NOTE(review): the names suggest the working precision of the residual; the
/// mapping to `RefinementEngine` is not visible in this chunk — confirm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum IterativeRefinement {
// Skip refinement.
NoRefine,
Single,
Double,
Extra,
}
impl Default for SuperLuDistOptions {
    /// Baseline options: no explicit process grid, MMD column ordering,
    /// large-diagonal row permutation, `Double` refinement, pivot threshold 1.0,
    /// and every 3D / asynchronous / tiny-pivot feature switched off.
    fn default() -> Self {
        Self {
            // Process layout.
            process_grid: None,
            enable_3d_factorization: false,
            process_grid_3d_depth: None,
            // Orderings.
            column_permutation: ColumnPermutation::MmdAta,
            row_permutation: RowPermutation::LargeDiag,
            // Pivoting.
            diagonal_pivot_threshold: 1.0,
            replace_tiny_pivots: false,
            static_pivoting: false,
            // Solve phase.
            iterative_refinement: IterativeRefinement::Double,
            // Panels and memory.
            panel_size: None,
            memory_tradeoff_factor: 1.0,
            max_concurrent_panels: 1,
            async_panel_updates: false,
            // Diagnostics.
            print_level: 0,
        }
    }
}
impl SuperLuDistOptions {
    /// Returns `true` when `print_level` is at least both `required` and
    /// `level`, i.e. when a message at that verbosity should be emitted.
    #[inline]
    pub fn enabled(&self, level: u8, required: u8) -> bool {
        self.print_level >= required && level <= self.print_level
    }

    /// Checks the options for internal consistency and, when `comm` is given,
    /// checks an explicit process grid against the communicator size.
    ///
    /// # Errors
    /// Returns `KError::InvalidInput` when:
    /// - `diagonal_pivot_threshold` is outside `[0, 1]` or NaN;
    /// - `panel_size` is `Some(0)`;
    /// - 3D factorization is enabled with `process_grid_3d_depth == Some(0)`;
    /// - `process_grid` has a zero dimension, or `rows * cols` differs from
    ///   `comm.size()` (a product that would overflow `usize` never matches);
    /// - `memory_tradeoff_factor` is NaN or below `0.1`.
    pub fn validate(&self, comm: Option<&UniverseComm>) -> Result<(), KError> {
        if !(0.0..=1.0).contains(&self.diagonal_pivot_threshold) {
            return Err(KError::InvalidInput(format!(
                "diagonal_pivot_threshold={} must be in [0,1]",
                self.diagonal_pivot_threshold
            )));
        }
        if self.panel_size == Some(0) {
            return Err(KError::InvalidInput("panel_size must be > 0".into()));
        }
        if self.enable_3d_factorization && self.process_grid_3d_depth == Some(0) {
            return Err(KError::InvalidInput("3D depth must be > 0".into()));
        }
        if let Some((r, c)) = self.process_grid {
            if r == 0 || c == 0 {
                return Err(KError::InvalidInput("process_grid dims must be > 0".into()));
            }
            if let Some(comm) = comm {
                let sz = comm.size();
                // `checked_mul` so a huge grid cannot wrap around (release
                // builds) and spuriously match the communicator size.
                if r.checked_mul(c) != Some(sz) {
                    return Err(KError::InvalidInput(format!(
                        "process_grid {r}x{c} does not match comm size {sz}"
                    )));
                }
            }
        }
        if self.memory_tradeoff_factor.is_nan() || self.memory_tradeoff_factor < 0.1 {
            return Err(KError::InvalidInput(
                "memory_tradeoff_factor must be >= 0.1".into(),
            ));
        }
        Ok(())
    }
}
/// Undirected adjacency-list graph of a sparse matrix pattern, used by the
/// minimum-degree orderings.
///
/// Both constructors in this file produce symmetric, self-loop-free lists in
/// ascending order (`Graph::from_matrix_pattern` sorts and dedups;
/// `OrderingAlgorithms::build_ata_graph` collects from `BTreeSet`s).
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct Graph {
// Number of vertices (matrix rows).
#[allow(dead_code)]
n: usize,
// `adj[v]` lists the neighbours of vertex `v`.
adj: Vec<Vec<usize>>,
}
#[allow(dead_code)]
impl Graph {
    /// Builds the symmetric, self-loop-free adjacency of `matrix`'s pattern,
    /// i.e. the structure of `A + Aᵀ` restricted to the leading `nrows`
    /// columns.
    ///
    /// Column indices `>= nrows` have no corresponding vertex and are skipped.
    /// This matches `OrderingAlgorithms::build_ata_graph`; previously they
    /// caused an out-of-bounds panic. Each adjacency list is sorted and free of
    /// duplicates.
    fn from_matrix_pattern(matrix: &CsrMatrix<f64>) -> Self {
        let n = matrix.nrows();
        let row_ptrs = matrix.row_ptr();
        let col_indices = matrix.col_idx();
        let mut adj = vec![Vec::new(); n];
        for i in 0..n {
            for &j in &col_indices[row_ptrs[i]..row_ptrs[i + 1]] {
                if i != j && j < n {
                    // Record both directions at once to symmetrize the pattern.
                    adj[i].push(j);
                    adj[j].push(i);
                }
            }
        }
        for list in &mut adj {
            list.sort_unstable();
            list.dedup();
        }
        Self { n, adj }
    }

    /// Number of stored neighbours of `v`, including eliminated ones.
    fn degree(&self, v: usize) -> usize {
        self.adj[v].len()
    }

    /// Stored neighbours of `v`, in ascending order.
    fn neighbors(&self, v: usize) -> &[usize] {
        &self.adj[v]
    }

    /// Eliminates `v` by connecting its uneliminated neighbours pairwise (fill).
    ///
    /// Returns every added edge in both orientations, `(u, w)` then `(w, u)`,
    /// in discovery order. Lists that gained edges are re-sorted once each, so
    /// they stay in ascending order. `v` itself is not removed from its
    /// neighbours' lists; callers filter with `eliminated`.
    fn eliminate_vertex(&mut self, v: usize, eliminated: &[bool]) -> Vec<(usize, usize)> {
        let neighbors: Vec<usize> = self.adj[v]
            .iter()
            .copied()
            .filter(|&u| !eliminated[u])
            .collect();
        let mut new_edges = Vec::new();
        for (a, &u) in neighbors.iter().enumerate() {
            for &w in &neighbors[a + 1..] {
                if !self.adj[u].contains(&w) {
                    new_edges.push((u, w));
                    new_edges.push((w, u));
                }
            }
        }
        for &(u, w) in &new_edges {
            self.adj[u].push(w);
        }
        // Sort each modified list once, instead of after every single push.
        let mut touched: Vec<usize> = new_edges.iter().map(|&(u, _)| u).collect();
        touched.sort_unstable();
        touched.dedup();
        for u in touched {
            self.adj[u].sort_unstable();
        }
        new_edges
    }
}
/// Elimination tree of a factor pattern, built by
/// `EliminationTree::from_symbolic_factor`.
#[derive(Debug, Clone)]
pub struct EliminationTree {
// `parent[j]` is the parent of node `j`; roots store `parent.len()` (i.e. `n`) as a sentinel.
pub parent: Vec<usize>,
// `children[p]` lists the children of `p` in increasing index order.
pub children: Vec<Vec<usize>>,
// Depth-first post-order of all nodes (children before parents).
pub post_order: Vec<usize>,
}
impl EliminationTree {
    /// Builds the elimination tree implied by a lower-triangular pattern.
    ///
    /// `parent[j]` is the smallest row `i` with `j < i < n` for which `(i, j)`
    /// is a key of `l_pattern` (only key presence matters, not the stored
    /// value). Columns without such a row are roots and get the sentinel `n`.
    /// `children[p]` lists children in increasing order. `post_order` is a
    /// depth-first post-order that visits roots and children in increasing
    /// index order.
    ///
    /// The traversal uses an explicit stack, so deep (path-like) trees cannot
    /// overflow the call stack. The parent search is a single pass over the
    /// pattern's keys instead of probing every `(i, j)` pair.
    fn from_symbolic_factor(n: usize, l_pattern: &HashMap<(usize, usize), bool>) -> Self {
        let mut parent = vec![n; n];
        for &(i, j) in l_pattern.keys() {
            if j < i && i < n && i < parent[j] {
                parent[j] = i;
            }
        }

        let mut children = vec![Vec::new(); n];
        for (child, &par) in parent.iter().enumerate() {
            if par < n {
                children[par].push(child);
            }
        }

        let mut post_order = Vec::with_capacity(n);
        // Stack frames are (node, index of the next child to descend into).
        let mut stack: Vec<(usize, usize)> = Vec::new();
        for root in (0..n).filter(|&v| parent[v] == n) {
            stack.push((root, 0));
            while let Some(frame) = stack.last_mut() {
                let node = frame.0;
                match children[node].get(frame.1) {
                    Some(&child) => {
                        frame.1 += 1;
                        stack.push((child, 0));
                    }
                    None => {
                        stack.pop();
                        post_order.push(node);
                    }
                }
            }
        }

        Self {
            parent,
            children,
            post_order,
        }
    }

    /// Returns the ancestors of `node`, nearest first, excluding `node` itself.
    pub fn ancestors(&self, node: usize) -> Vec<usize> {
        let mut ancestors = Vec::new();
        let mut current = self.parent[node];
        while current < self.parent.len() {
            ancestors.push(current);
            current = self.parent[current];
        }
        ancestors
    }
}
pub struct OrderingAlgorithms;
impl OrderingAlgorithms {
pub fn natural_ordering(n: usize) -> Vec<usize> {
(0..n).collect()
}
pub fn amd_ordering(matrix: &CsrMatrix<f64>) -> Vec<usize> {
let n = matrix.nrows();
let mut graph = Graph::from_matrix_pattern(matrix);
let mut perm = Vec::new();
let mut eliminated = vec![false; n];
for _ in 0..n {
let mut min_degree = usize::MAX;
let mut min_vertex = 0;
for v in 0..n {
if !eliminated[v] {
let degree = graph.adj[v].iter().filter(|&&u| !eliminated[u]).count();
if degree < min_degree {
min_degree = degree;
min_vertex = v;
}
}
}
perm.push(min_vertex);
eliminated[min_vertex] = true;
graph.eliminate_vertex(min_vertex, &eliminated);
}
perm
}
pub fn mmd_ata_ordering(matrix: &CsrMatrix<f64>) -> Vec<usize> {
let n = matrix.nrows();
if n == 0 {
return Vec::new();
}
let ata_graph = Self::build_ata_graph(matrix);
let mut degree = vec![0; n];
let mut eliminated = vec![false; n];
let mut ordering = Vec::with_capacity(n);
let mut adj_lists = ata_graph.adj.clone();
for i in 0..n {
degree[i] = adj_lists[i].len();
}
for step in 0..n {
let pivot = Self::select_minimum_degree_vertex(°ree, &eliminated);
if pivot >= n {
break;
}
ordering.push(pivot);
eliminated[pivot] = true;
#[cfg(feature = "logging")]
if step % 1000 == 0 {
log::debug!(
"MMD step {}/{}, pivot {} with degree {}",
step,
n,
pivot,
degree[pivot]
);
}
#[cfg(not(feature = "logging"))]
let _ = step;
let pivot_neighbors: Vec<usize> = adj_lists[pivot]
.iter()
.filter(|&&v| !eliminated[v])
.copied()
.collect();
let new_edges =
Self::eliminate_vertex_mmd(pivot, &pivot_neighbors, &mut adj_lists, &eliminated);
Self::update_degrees_after_elimination(
pivot,
&pivot_neighbors,
&new_edges,
&mut degree,
&eliminated,
&adj_lists,
);
}
ordering
}
pub fn mmd_ata_ordering_distributed(
matrix: &CsrMatrix<f64>,
comm: &UniverseComm,
distribution: &BlockCyclicDistribution,
) -> Result<Vec<usize>, KError> {
let n = matrix.nrows();
let global_pattern = if comm.size() <= 1 {
matrix.clone()
} else {
#[cfg(feature = "mpi")]
{
match comm {
UniverseComm::Mpi(comm_impl) => {
let global_rows = distribution.global_rows;
let global_cols = distribution.global_cols;
let mut adjacency = vec![Vec::<usize>::new(); global_rows];
let rp = matrix.row_ptr();
let ci = matrix.col_idx();
let local_rows = distribution.local_rows;
let mut encoded: Vec<usize> = Vec::new();
for local_row in 0..local_rows {
let global_row = distribution.local_to_global_row(local_row);
let start = rp[global_row];
let end = rp[global_row + 1];
encoded.push(global_row);
encoded.push(end - start);
encoded.extend_from_slice(&ci[start..end]);
}
let local_len = encoded.len() as i32;
let mut lengths = vec![0i32; comm_impl.size];
comm_impl
.world
.all_gather_into(&local_len, &mut lengths[..]);
let lengths: Vec<usize> = lengths.iter().map(|&l| l as usize).collect();
let max_len = lengths.iter().copied().max().unwrap_or(0);
if max_len > 0 {
let mut padded = encoded.clone();
padded.resize(max_len, usize::MAX);
let mut gathered = vec![0usize; max_len * comm_impl.size];
comm_impl
.world
.all_gather_into(&padded[..], &mut gathered[..]);
for (rank_idx, &len) in lengths.iter().enumerate() {
let chunk = &gathered[rank_idx * max_len..rank_idx * max_len + len];
let mut idx = 0;
while idx < chunk.len() {
if idx + 1 >= chunk.len() {
break;
}
let global_row = chunk[idx];
idx += 1;
let nnz = chunk[idx];
idx += 1;
if idx + nnz > chunk.len() {
break;
}
adjacency[global_row].extend_from_slice(&chunk[idx..idx + nnz]);
idx += nnz;
}
}
}
let mut row_ptr = Vec::with_capacity(global_rows + 1);
let mut col_idx = Vec::new();
row_ptr.push(0);
for cols in adjacency.iter_mut() {
cols.sort_unstable();
cols.dedup();
col_idx.extend_from_slice(cols);
row_ptr.push(col_idx.len());
}
let values = vec![0.0; col_idx.len()];
CsrMatrix::from_csr(global_rows, global_cols, row_ptr, col_idx, values)
}
_ => matrix.clone(),
}
}
#[cfg(not(feature = "mpi"))]
{
let _ = distribution;
matrix.clone()
}
};
let graph = Self::build_ata_graph(&global_pattern);
#[cfg(not(feature = "logging"))]
let _ = comm;
let mut adj_lists = graph.adj;
let mut ordering = Vec::with_capacity(n);
let mut eliminated = vec![false; n];
let mut degree: Vec<usize> = adj_lists.iter().map(|adj| adj.len()).collect();
#[cfg(feature = "logging")]
if comm.rank() == 0 {
log::debug!("Starting distributed MMD ordering for matrix {n}x{n}");
}
for step in 0..n {
#[cfg(not(feature = "logging"))]
let _ = step;
let pivot = Self::select_minimum_degree_vertex(°ree, &eliminated);
if pivot >= n {
break;
}
ordering.push(pivot);
eliminated[pivot] = true;
#[cfg(feature = "logging")]
if step % 1000 == 0 && comm.rank() == 0 {
log::debug!(
"Distributed MMD step {}/{}, pivot {} with degree {}",
step,
n,
pivot,
degree[pivot]
);
}
let pivot_neighbors: Vec<usize> = adj_lists[pivot]
.iter()
.filter(|&&v| !eliminated[v])
.copied()
.collect();
let new_edges =
Self::eliminate_vertex_mmd(pivot, &pivot_neighbors, &mut adj_lists, &eliminated);
Self::update_degrees_after_elimination(
pivot,
&pivot_neighbors,
&new_edges,
&mut degree,
&eliminated,
&adj_lists,
);
}
Ok(ordering)
}
fn build_ata_graph(matrix: &CsrMatrix<f64>) -> Graph {
let n = matrix.nrows();
let mut adj = vec![std::collections::BTreeSet::new(); n];
let row_ptrs = matrix.row_ptr();
let col_indices = matrix.col_idx();
for i in 0..n {
for idx in row_ptrs[i]..row_ptrs[i + 1] {
let j = col_indices[idx];
if i != j && j < n {
adj[i].insert(j);
}
}
}
for i in 0..n {
for idx in row_ptrs[i]..row_ptrs[i + 1] {
let j = col_indices[idx];
if i != j && j < n {
adj[j].insert(i);
}
}
}
let adj_vec: Vec<Vec<usize>> = adj
.into_iter()
.map(|set| set.into_iter().collect())
.collect();
Graph { n, adj: adj_vec }
}
fn select_minimum_degree_vertex(degree: &[usize], eliminated: &[bool]) -> usize {
let mut min_degree = usize::MAX;
let mut min_vertex = usize::MAX;
for (i, °) in degree.iter().enumerate() {
if !eliminated[i] && deg < min_degree {
min_degree = deg;
min_vertex = i;
}
}
min_vertex
}
fn eliminate_vertex_mmd(
pivot: usize,
neighbors: &[usize],
adj_lists: &mut [Vec<usize>],
eliminated: &[bool],
) -> Vec<(usize, usize)> {
let mut new_edges = Vec::new();
for i in 0..neighbors.len() {
for j in (i + 1)..neighbors.len() {
let u = neighbors[i];
let v = neighbors[j];
if !eliminated[u] && !eliminated[v] {
if !adj_lists[u].contains(&v) {
adj_lists[u].push(v);
adj_lists[v].push(u);
new_edges.push((u, v));
}
}
}
}
for &neighbor in neighbors {
if !eliminated[neighbor] {
adj_lists[neighbor].retain(|&x| x != pivot);
}
}
for &neighbor in neighbors {
if !eliminated[neighbor] {
adj_lists[neighbor].sort_unstable();
}
}
new_edges
}
fn update_degrees_after_elimination(
pivot: usize,
pivot_neighbors: &[usize],
new_edges: &[(usize, usize)],
degree: &mut [usize],
eliminated: &[bool],
adj_lists: &[Vec<usize>],
) {
degree[pivot] = 0;
let mut affected_vertices = std::collections::HashSet::new();
for &v in pivot_neighbors {
if !eliminated[v] {
affected_vertices.insert(v);
}
}
for &(u, v) in new_edges {
if !eliminated[u] {
affected_vertices.insert(u);
}
if !eliminated[v] {
affected_vertices.insert(v);
}
}
for &v in &affected_vertices {
if !eliminated[v] {
degree[v] = adj_lists[v].iter().filter(|&&u| !eliminated[u]).count();
}
}
}
pub fn metis_ordering(matrix: &CsrMatrix<f64>) -> Result<Vec<usize>, KError> {
#[cfg(feature = "logging")]
log::warn!("METIS ordering not implemented, falling back to AMD");
Ok(Self::amd_ordering(matrix))
}
pub fn parmetis_ordering(
matrix: &CsrMatrix<f64>,
comm: &UniverseComm,
distribution: &BlockCyclicDistribution,
) -> Result<Vec<usize>, KError> {
if comm.size() <= 1 {
return Ok(Self::mmd_ata_ordering(matrix));
}
#[cfg(feature = "logging")]
if comm.rank() == 0 {
log::debug!("Using distributed MMD ordering for ParMETIS path");
}
Self::mmd_ata_ordering_distributed(matrix, comm, distribution)
}
}
pub struct SymbolicFactorizer;
impl SymbolicFactorizer {
pub fn compute_symbolic_pattern(
matrix: &CsrMatrix<f64>,
col_perm: &[usize],
row_perm: &[usize],
) -> Result<HashMap<(usize, usize), bool>, KError> {
let n = matrix.nrows();
let etree = Self::build_elimination_tree_from_matrix(matrix, col_perm, row_perm)?;
let mut l_pattern = HashMap::new();
let mut visited = vec![false; n];
let mut reach_set = Vec::new();
let row_ptrs = matrix.row_ptr();
let col_indices = matrix.col_idx();
for k in 0..n {
visited.fill(false);
reach_set.clear();
Self::compute_reach_set(
k,
&etree,
row_ptrs,
col_indices,
row_perm,
col_perm,
&mut visited,
&mut reach_set,
);
for &i in &reach_set {
if i >= k {
l_pattern.insert((i, k), true);
}
}
l_pattern.insert((k, k), true);
}
#[cfg(feature = "logging")]
log::debug!(
"Symbolic factorization computed {} nonzeros in L factor",
l_pattern.len()
);
Ok(l_pattern)
}
fn build_elimination_tree_from_matrix(
matrix: &CsrMatrix<f64>,
col_perm: &[usize],
row_perm: &[usize],
) -> Result<Vec<usize>, KError> {
let n = matrix.nrows();
let mut parent = vec![n; n]; let mut ancestor = vec![0; n];
let row_ptrs = matrix.row_ptr();
let col_indices = matrix.col_idx();
for k in 0..n {
parent[k] = n; ancestor[k] = k;
let perm_row = row_perm[k];
let start = row_ptrs[perm_row];
let end = row_ptrs[perm_row + 1];
for idx in start..end {
let orig_col = col_indices[idx];
if let Some(j) = col_perm.iter().position(|&c| c == orig_col)
&& j < k
{
let mut root = j;
while ancestor[root] != root && ancestor[root] < k {
root = ancestor[root];
}
if parent[root] == n {
parent[root] = k;
}
ancestor[j] = k;
}
}
}
Ok(parent)
}
fn compute_reach_set(
col: usize,
etree: &[usize],
row_ptrs: &[usize],
col_indices: &[usize],
row_perm: &[usize],
col_perm: &[usize],
visited: &mut [bool],
reach_set: &mut Vec<usize>,
) {
let n = etree.len();
if col < n && !visited[col] {
Self::dfs_reach(col, etree, visited, reach_set);
}
if col < row_perm.len() {
let perm_row = row_perm[col];
if perm_row < row_ptrs.len() - 1 {
let start = row_ptrs[perm_row];
let end = row_ptrs[perm_row + 1];
for idx in start..end {
let orig_col = col_indices[idx];
if let Some(j) = col_perm.iter().position(|&c| c == orig_col)
&& j < col
&& !visited[j]
{
Self::dfs_reach(j, etree, visited, reach_set);
}
}
}
}
}
fn dfs_reach(node: usize, etree: &[usize], visited: &mut [bool], reach_set: &mut Vec<usize>) {
let n = etree.len();
if node >= n || visited[node] {
return;
}
visited[node] = true;
reach_set.push(node);
let parent = etree[node];
if parent < n {
Self::dfs_reach(parent, etree, visited, reach_set);
}
}
pub fn build_elimination_tree(
n: usize,
l_pattern: &HashMap<(usize, usize), bool>,
) -> EliminationTree {
EliminationTree::from_symbolic_factor(n, l_pattern)
}
}
/// Per-solver state: process layout, data distribution, options, and the
/// cached symbolic/numeric factorizations and solve workspace.
///
/// The private factor and workspace slots are read and written through the
/// accessor/setter methods on this type.
pub struct SuperLuDistData {
pub process_grid: ProcessGrid,
pub distribution: BlockCyclicDistribution,
pub comm: UniverseComm,
pub options: SuperLuDistOptions,
// NOTE(review): presumably set once numeric factorization succeeds — not updated in this chunk.
pub factored: bool,
// NOTE(review): presumably this rank's part of the matrix — confirm against the caller.
pub local_matrix: Option<CsrMatrix<f64>>,
symbolic_factor: Option<SymbolicFactorization>,
// Read by `RefinementEngine` for correction solves.
numeric_factor: Option<NumericFactorization>,
solve_workspace: Option<SolveWorkspace>,
}
/// Result of the symbolic analysis phase.
#[derive(Debug, Clone)]
pub struct SymbolicFactorization {
// Column permutation used for the analysis.
pub col_perm: Vec<usize>,
// Row permutation used for the analysis.
pub row_perm: Vec<usize>,
pub etree: EliminationTree,
// Structural nonzeros of L as `(row, col)` keys (values are `true`); presumably
// from `SymbolicFactorizer::compute_symbolic_pattern` — confirm.
pub l_pattern: HashMap<(usize, usize), bool>,
// Structural nonzeros of U; how it is produced is not visible in this chunk.
pub u_pattern: HashMap<(usize, usize), bool>,
}
/// Scratch storage kept between solves.
#[derive(Debug)]
pub struct SolveWorkspace {
pub workspace: SuperLuDistWorkspace,
// NOTE(review): keys look like process ranks — verify against users of this field.
pub process_vectors: HashMap<usize, Vec<f64>>,
// Named global-length vectors (naming convention not visible in this chunk).
pub global_vectors: HashMap<String, Vec<f64>>,
}
impl SuperLuDistData {
    /// Borrowed symbolic factorization, or `None` if none has been stored.
    pub fn symbolic_factor(&self) -> Option<&SymbolicFactorization> {
        Option::as_ref(&self.symbolic_factor)
    }

    /// Borrowed numeric factorization, or `None` if none has been stored.
    pub fn numeric_factor(&self) -> Option<&NumericFactorization> {
        Option::as_ref(&self.numeric_factor)
    }

    /// Borrowed solve workspace, or `None` if none has been stored.
    pub fn solve_workspace(&self) -> Option<&SolveWorkspace> {
        Option::as_ref(&self.solve_workspace)
    }

    /// Stores `factor`, dropping any previously stored symbolic factorization.
    pub fn set_symbolic_factor(&mut self, factor: SymbolicFactorization) {
        drop(self.symbolic_factor.replace(factor));
    }

    /// Stores `factor`, dropping any previously stored numeric factorization.
    pub fn set_numeric_factor(&mut self, factor: NumericFactorization) {
        drop(self.numeric_factor.replace(factor));
    }

    /// Stores `workspace`, dropping any previously stored solve workspace.
    pub fn set_solve_workspace(&mut self, workspace: SolveWorkspace) {
        drop(self.solve_workspace.replace(workspace));
    }
}
/// Stopping criteria for `RefinementEngine::refine_solution`.
#[derive(Debug, Clone)]
pub struct RefinementConfig {
// Upper bound on correction sweeps.
pub max_iterations: usize,
// Absolute tolerance on the residual 2-norm.
pub tolerance: f64,
// Tolerance on residual_norm / initial_residual_norm.
pub relative_tolerance: f64,
// A sweep whose residual ratio to the previous sweep exceeds this stops with `Stagnation`.
pub min_improvement_factor: f64,
}
impl Default for RefinementConfig {
    /// Five sweeps, absolute tolerance `1e-12`, relative tolerance `1e-6`, and
    /// a required per-sweep residual reduction below a factor of `0.9`.
    fn default() -> Self {
        let max_iterations = 5;
        let tolerance = 1e-12;
        let relative_tolerance = 1e-6;
        let min_improvement_factor = 0.9;
        Self {
            max_iterations,
            tolerance,
            relative_tolerance,
            min_improvement_factor,
        }
    }
}
/// How `RefinementEngine` scales the residual `r = b - A x`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ResidualMethod {
// Unscaled `b - A x`.
Standard,
// Divided by the global 2-norm of `b` (when it is non-zero).
Scaled,
// Each `r_i` divided by `max(|b_i|, |(A x)_i|)` (when that is non-zero).
ComponentWise,
}
/// Iterative-refinement driver: repeatedly solves for a correction using the
/// stored factorization and updates the solution.
#[derive(Debug)]
pub struct RefinementEngine {
config: RefinementConfig,
residual_method: ResidualMethod,
// Scratch buffers, resized to the problem size by `setup_workspace`.
residual_workspace: Vec<f64>,
correction_workspace: Vec<f64>,
matvec_workspace: Vec<f64>,
// Statistics of the most recent successful `refine_solution` call.
last_stats: Option<RefinementStats>,
}
/// Outcome of one `RefinementEngine::refine_solution` call.
#[derive(Debug, Clone)]
pub struct RefinementStats {
// Number of correction sweeps performed.
pub iterations: usize,
pub initial_residual_norm: f64,
pub final_residual_norm: f64,
// Residual norms, starting with the initial one, then one per sweep.
pub residual_history: Vec<f64>,
pub converged: bool,
// Why refinement stopped; `MaxIterations` unless another reason was recorded.
pub convergence_reason: RefinementConvergence,
// Wall-clock time in seconds.
pub refinement_time: f64,
}
/// Reason iterative refinement stopped.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum RefinementConvergence {
// Residual norm reached `RefinementConfig::tolerance`.
AbsoluteTolerance,
// Residual norm relative to the initial one reached `relative_tolerance`.
RelativeTolerance,
// Sweep budget exhausted without another stopping reason.
MaxIterations,
// A sweep did not reduce the residual enough (`min_improvement_factor`).
Stagnation,
// Residual grew too large relative to the initial residual; see
// `RefinementEngine::refine_solution` for the exact test.
Divergence,
}
impl RefinementEngine {
pub fn new(config: RefinementConfig, residual_method: ResidualMethod) -> Self {
Self {
config,
residual_method,
residual_workspace: Vec::new(),
correction_workspace: Vec::new(),
matvec_workspace: Vec::new(),
last_stats: None,
}
}
pub fn with_defaults() -> Self {
Self::new(RefinementConfig::default(), ResidualMethod::Standard)
}
pub fn setup_workspace(&mut self, n: usize) {
self.residual_workspace.resize(n, 0.0);
self.correction_workspace.resize(n, 0.0);
self.matvec_workspace.resize(n, 0.0);
}
pub fn refine_solution(
&mut self,
matrix: &CsrMatrix<f64>,
rhs: &[f64],
solution: &mut [f64],
superlu_data: &SuperLuDistData,
comm: &UniverseComm,
) -> Result<RefinementStats, KError> {
let n = solution.len();
if n != rhs.len() {
return Err(KError::InvalidInput(
"Solution and RHS dimension mismatch".to_string(),
));
}
self.setup_workspace(n);
let start_time = std::time::Instant::now();
let mut stats = RefinementStats {
iterations: 0,
initial_residual_norm: 0.0,
final_residual_norm: 0.0,
residual_history: Vec::new(),
converged: false,
convergence_reason: RefinementConvergence::MaxIterations,
refinement_time: 0.0,
};
let mut residual_workspace = self.residual_workspace.clone();
let mut correction_workspace = self.correction_workspace.clone();
let mut matvec_workspace = self.matvec_workspace.clone();
Self::compute_residual_static(
matrix,
rhs,
solution,
&mut residual_workspace,
&mut matvec_workspace,
self.residual_method,
comm,
)?;
let initial_residual_norm = Self::compute_residual_norm_static(&residual_workspace, comm)?;
stats.initial_residual_norm = initial_residual_norm;
stats.final_residual_norm = initial_residual_norm;
stats.residual_history.push(initial_residual_norm);
if self.check_convergence(initial_residual_norm, initial_residual_norm, 0) {
stats.converged = true;
stats.convergence_reason = RefinementConvergence::AbsoluteTolerance;
stats.refinement_time = start_time.elapsed().as_secs_f64();
self.last_stats = Some(stats.clone());
return Ok(stats);
}
let mut previous_residual_norm = initial_residual_norm;
for iter in 0..self.config.max_iterations {
stats.iterations = iter + 1;
Self::solve_correction_static(
&residual_workspace,
&mut correction_workspace,
superlu_data,
comm,
)?;
for i in 0..n {
solution[i] += correction_workspace[i];
}
Self::compute_residual_static(
matrix,
rhs,
solution,
&mut residual_workspace,
&mut matvec_workspace,
self.residual_method,
comm,
)?;
let residual_norm = Self::compute_residual_norm_static(&residual_workspace, comm)?;
stats.final_residual_norm = residual_norm;
stats.residual_history.push(residual_norm);
if self.check_convergence(residual_norm, initial_residual_norm, iter + 1) {
stats.converged = true;
stats.convergence_reason = if residual_norm <= self.config.tolerance {
RefinementConvergence::AbsoluteTolerance
} else {
RefinementConvergence::RelativeTolerance
};
break;
}
let improvement_factor = residual_norm / previous_residual_norm;
if improvement_factor > self.config.min_improvement_factor {
stats.convergence_reason = RefinementConvergence::Stagnation;
break;
}
if residual_norm > initial_residual_norm * 10.0 {
stats.convergence_reason = RefinementConvergence::Divergence;
break;
}
previous_residual_norm = residual_norm;
}
stats.refinement_time = start_time.elapsed().as_secs_f64();
self.last_stats = Some(stats.clone());
Ok(stats)
}
fn compute_residual_static(
matrix: &CsrMatrix<f64>,
rhs: &[f64],
solution: &[f64],
residual: &mut [f64],
matvec_workspace: &mut [f64],
residual_method: ResidualMethod,
comm: &UniverseComm,
) -> Result<(), KError> {
residual.copy_from_slice(rhs);
Self::distributed_sparse_matvec_static(matrix, solution, matvec_workspace, comm)?;
for i in 0..residual.len() {
residual[i] -= matvec_workspace[i];
}
match residual_method {
ResidualMethod::Standard => {
}
ResidualMethod::Scaled => {
let rhs_norm = Self::compute_vector_norm_static(rhs, comm)?;
if rhs_norm > 0.0 {
for r in residual.iter_mut() {
*r /= rhs_norm;
}
}
}
ResidualMethod::ComponentWise => {
for i in 0..residual.len() {
let scale = f64::max(rhs[i].abs(), matvec_workspace[i].abs());
if scale > 0.0 {
residual[i] /= scale;
}
}
}
}
Ok(())
}
fn distributed_sparse_matvec_static(
matrix: &CsrMatrix<f64>,
x: &[f64],
y: &mut [f64],
_comm: &UniverseComm,
) -> Result<(), KError> {
let row_ptrs = matrix.row_ptr();
let col_indices = matrix.col_idx();
let values = matrix.values();
y.fill(0.0);
for i in 0..matrix.nrows() {
for idx in row_ptrs[i]..row_ptrs[i + 1] {
let j = col_indices[idx];
let val = values[idx];
y[i] += val * x[j];
}
}
Ok(())
}
fn solve_correction_static(
residual: &[f64],
correction: &mut [f64],
superlu_data: &SuperLuDistData,
comm: &UniverseComm,
) -> Result<(), KError> {
let numeric_factor = superlu_data
.numeric_factor
.as_ref()
.ok_or_else(|| KError::SolveError("Numeric factorization not available".to_string()))?;
correction.copy_from_slice(residual);
let mut temp_result = correction.to_vec();
DistributedTriangularSolver::forward_solve(
residual,
&mut temp_result,
numeric_factor,
&superlu_data.distribution,
comm,
CommPattern::PointToPoint,
false,
#[cfg(feature = "superlu3d")]
None,
)?;
DistributedTriangularSolver::backward_solve(
&temp_result,
correction,
numeric_factor,
&superlu_data.distribution,
comm,
CommPattern::PointToPoint,
false,
#[cfg(feature = "superlu3d")]
None,
)?;
Ok(())
}
fn compute_residual_norm_static(residual: &[f64], comm: &UniverseComm) -> Result<f64, KError> {
let local_norm_sq: f64 = residual.iter().map(|x| x * x).sum();
let global_norm_sq = comm.all_reduce_f64(local_norm_sq);
Ok(global_norm_sq.sqrt())
}
fn compute_vector_norm_static(vector: &[f64], comm: &UniverseComm) -> Result<f64, KError> {
let local_norm_sq: f64 = vector.iter().map(|x| x * x).sum();
let global_norm_sq = comm.all_reduce_f64(local_norm_sq);
Ok(global_norm_sq.sqrt())
}
fn check_convergence(&self, current_norm: f64, initial_norm: f64, iteration: usize) -> bool {
if iteration == 0 {
return false; }
if current_norm <= self.config.tolerance {
return true;
}
if initial_norm > 0.0 && current_norm / initial_norm <= self.config.relative_tolerance {
return true;
}
false
}
pub fn last_stats(&self) -> Option<&RefinementStats> {
self.last_stats.as_ref()
}
pub fn config(&self) -> &RefinementConfig {
&self.config
}
pub fn set_config(&mut self, config: RefinementConfig) {
self.config = config;
}
pub fn set_residual_method(&mut self, method: ResidualMethod) {
self.residual_method = method;
}
}
/// Size-bucketed recycling pool for `Vec<f64>` and `Vec<usize>` buffers.
///
/// Returned vectors are bucketed by capacity; `get_*` looks up a bucket by
/// the requested length. These coincide for vectors the pool handed out.
/// `current_memory_usage` counts the bytes of vectors currently *held* by the
/// pool: it increases on return and now decreases when a pooled vector is
/// handed out again. Previously it only grew, so the pool eventually refused
/// all returns.
#[derive(Debug)]
pub struct MemoryPool {
    f64_pools: std::collections::HashMap<usize, Vec<Vec<f64>>>,
    usize_pools: std::collections::HashMap<usize, Vec<Vec<usize>>>,
    max_vectors_per_size: usize,
    memory_limit: usize,
    current_memory_usage: usize,
}
impl MemoryPool {
    /// Creates a pool holding at most `max_vectors_per_size` vectors per size
    /// bucket and at most `memory_limit_mb` MiB in total.
    pub fn new(max_vectors_per_size: usize, memory_limit_mb: usize) -> Self {
        Self {
            f64_pools: std::collections::HashMap::new(),
            usize_pools: std::collections::HashMap::new(),
            max_vectors_per_size,
            memory_limit: memory_limit_mb * 1024 * 1024,
            current_memory_usage: 0,
        }
    }

    /// Returns a zero-filled vector of length `size`, reusing a pooled one
    /// when available.
    pub fn get_f64_vector(&mut self, size: usize) -> Vec<f64> {
        if let Some(pool) = self.f64_pools.get_mut(&size)
            && let Some(mut vec) = pool.pop()
        {
            // The vector leaves the pool, so it no longer counts against the budget.
            self.current_memory_usage = self
                .current_memory_usage
                .saturating_sub(vec.capacity() * std::mem::size_of::<f64>());
            vec.clear();
            vec.resize(size, 0.0);
            return vec;
        }
        vec![0.0; size]
    }

    /// Offers `vec` back to the pool. It is silently dropped if the memory
    /// limit would be exceeded or its size bucket is full.
    pub fn return_f64_vector(&mut self, mut vec: Vec<f64>) {
        let size = vec.capacity();
        let memory_size = size * std::mem::size_of::<f64>();
        if self.current_memory_usage + memory_size > self.memory_limit {
            return;
        }
        let pool = self.f64_pools.entry(size).or_default();
        if pool.len() < self.max_vectors_per_size {
            vec.clear();
            pool.push(vec);
            self.current_memory_usage += memory_size;
        }
    }

    /// Returns a zero-filled vector of length `size`, reusing a pooled one
    /// when available.
    pub fn get_usize_vector(&mut self, size: usize) -> Vec<usize> {
        if let Some(pool) = self.usize_pools.get_mut(&size)
            && let Some(mut vec) = pool.pop()
        {
            // The vector leaves the pool, so it no longer counts against the budget.
            self.current_memory_usage = self
                .current_memory_usage
                .saturating_sub(vec.capacity() * std::mem::size_of::<usize>());
            vec.clear();
            vec.resize(size, 0);
            return vec;
        }
        vec![0; size]
    }

    /// Offers `vec` back to the pool. It is silently dropped if the memory
    /// limit would be exceeded or its size bucket is full.
    pub fn return_usize_vector(&mut self, mut vec: Vec<usize>) {
        let size = vec.capacity();
        let memory_size = size * std::mem::size_of::<usize>();
        if self.current_memory_usage + memory_size > self.memory_limit {
            return;
        }
        let pool = self.usize_pools.entry(size).or_default();
        if pool.len() < self.max_vectors_per_size {
            vec.clear();
            pool.push(vec);
            self.current_memory_usage += memory_size;
        }
    }

    /// Drops every pooled vector and resets the usage counter.
    pub fn clear(&mut self) {
        self.f64_pools.clear();
        self.usize_pools.clear();
        self.current_memory_usage = 0;
    }

    /// Bytes currently held by the pool.
    pub fn memory_usage(&self) -> usize {
        self.current_memory_usage
    }

    /// Snapshot of the pool's contents.
    pub fn memory_stats(&self) -> MemoryStats {
        let f64_vectors: usize = self.f64_pools.values().map(|pool| pool.len()).sum();
        let usize_vectors: usize = self.usize_pools.values().map(|pool| pool.len()).sum();
        MemoryStats {
            total_memory_bytes: self.current_memory_usage,
            f64_vectors_pooled: f64_vectors,
            usize_vectors_pooled: usize_vectors,
            f64_pool_sizes: self.f64_pools.len(),
            usize_pool_sizes: self.usize_pools.len(),
        }
    }
}
/// Snapshot of a `MemoryPool`, as produced by `MemoryPool::memory_stats`.
#[derive(Debug, Clone)]
pub struct MemoryStats {
// Bytes currently accounted to the pool (`current_memory_usage`).
pub total_memory_bytes: usize,
// Number of pooled `Vec<f64>` buffers across all size buckets.
pub f64_vectors_pooled: usize,
// Number of pooled `Vec<usize>` buffers across all size buckets.
pub usize_vectors_pooled: usize,
// Number of distinct `f64` size buckets.
pub f64_pool_sizes: usize,
// Number of distinct `usize` size buckets.
pub usize_pool_sizes: usize,
}
/// Per-process send/receive buffers plus one local buffer, backed by a
/// `MemoryPool`.
///
/// Requested sizes are capped at `max_buffer_size` elements. A buffer whose
/// length already matches is returned as-is, contents included. A buffer of
/// the wrong length is replaced, and the old allocation is now handed back to
/// the pool instead of being dropped.
#[derive(Debug)]
pub struct CommBufferManager {
    send_buffers: HashMap<usize, Vec<f64>>,
    recv_buffers: HashMap<usize, Vec<f64>>,
    local_buffer: Vec<f64>,
    max_buffer_size: usize,
    memory_pool: MemoryPool,
}
impl CommBufferManager {
    /// Creates a manager whose pool keeps at most 4 vectors per size and
    /// half of `memory_limit_mb`.
    pub fn new(max_buffer_size: usize, memory_limit_mb: usize) -> Self {
        Self {
            send_buffers: HashMap::new(),
            recv_buffers: HashMap::new(),
            local_buffer: Vec::new(),
            max_buffer_size,
            memory_pool: MemoryPool::new(4, memory_limit_mb / 2),
        }
    }

    /// Send buffer for `process`, with length `min(size, max_buffer_size)`.
    pub fn get_send_buffer(&mut self, process: usize, size: usize) -> &mut Vec<f64> {
        let buffer_size = size.min(self.max_buffer_size);
        Self::sized_entry(&mut self.send_buffers, &mut self.memory_pool, process, buffer_size)
    }

    /// Receive buffer for `process`, with length `min(size, max_buffer_size)`.
    pub fn get_recv_buffer(&mut self, process: usize, size: usize) -> &mut Vec<f64> {
        let buffer_size = size.min(self.max_buffer_size);
        Self::sized_entry(&mut self.recv_buffers, &mut self.memory_pool, process, buffer_size)
    }

    /// Local scratch buffer, with length `min(size, max_buffer_size)`.
    pub fn get_local_buffer(&mut self, size: usize) -> &mut Vec<f64> {
        let buffer_size = size.min(self.max_buffer_size);
        if self.local_buffer.len() != buffer_size {
            let fresh = self.memory_pool.get_f64_vector(buffer_size);
            let old = std::mem::replace(&mut self.local_buffer, fresh);
            Self::recycle(&mut self.memory_pool, old);
        }
        &mut self.local_buffer
    }

    /// Returns every buffer to the pool and empties the maps.
    pub fn clear_buffers(&mut self) {
        for (_, buffer) in self.send_buffers.drain() {
            self.memory_pool.return_f64_vector(buffer);
        }
        for (_, buffer) in self.recv_buffers.drain() {
            self.memory_pool.return_f64_vector(buffer);
        }
        if !self.local_buffer.is_empty() {
            let buffer = std::mem::take(&mut self.local_buffer);
            self.memory_pool.return_f64_vector(buffer);
        }
    }

    /// Statistics of the backing pool.
    pub fn memory_stats(&self) -> MemoryStats {
        self.memory_pool.memory_stats()
    }

    /// Fetches (or creates) the buffer for `process` in `buffers`, replacing
    /// it with a pooled vector of `buffer_size` elements if its length differs.
    fn sized_entry<'a>(
        buffers: &'a mut HashMap<usize, Vec<f64>>,
        pool: &mut MemoryPool,
        process: usize,
        buffer_size: usize,
    ) -> &'a mut Vec<f64> {
        let buffer = buffers
            .entry(process)
            .or_insert_with(|| pool.get_f64_vector(buffer_size));
        if buffer.len() != buffer_size {
            let old = std::mem::replace(buffer, pool.get_f64_vector(buffer_size));
            Self::recycle(pool, old);
        }
        buffer
    }

    /// Hands a replaced buffer back to the pool. Buffers that never allocated
    /// (capacity 0) are not worth pooling.
    fn recycle(pool: &mut MemoryPool, buffer: Vec<f64>) {
        if buffer.capacity() > 0 {
            pool.return_f64_vector(buffer);
        }
    }
}
/// Named scratch vectors and communication buffers for the solver.
#[derive(Debug)]
pub struct SuperLuDistWorkspace {
// Scratch vectors keyed by name (e.g. "solution", "residual", "local_work").
temp_vectors: HashMap<String, Vec<f64>>,
comm_buffers: CommBufferManager,
// Pool that temp vectors are drawn from and returned to.
memory_pool: MemoryPool,
config: WorkspaceConfig,
// Expected length per registered name, filled by `setup_for_problem`;
// names missing here are dropped by `cleanup_unused_vectors`.
vector_sizes: HashMap<String, usize>,
}
/// Tuning knobs for `SuperLuDistWorkspace`.
#[derive(Debug, Clone)]
pub struct WorkspaceConfig {
// Total budget in MiB, split evenly between the vector pool and the comm buffers.
pub memory_limit_mb: usize,
// Maximum pooled vectors per size bucket.
pub max_vectors_per_size: usize,
// Cap (in f64 elements) on communication-buffer length.
pub max_comm_buffer_size: usize,
// When true, `return_temp_vector` hands vectors back to the pool.
pub aggressive_reuse: bool,
// What `setup_for_problem` preallocates.
pub preallocation_strategy: PreallocationStrategy,
}
/// What `SuperLuDistWorkspace::setup_for_problem` allocates up front.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PreallocationStrategy {
// Nothing.
None,
// "solution" and "residual" (global length).
MatrixSize,
// "local_work" plus send/recv buffers for every process.
ProcessGrid,
// "panel_work" and "block_work".
BlockSize,
// All of the above.
Full,
}
impl Default for WorkspaceConfig {
    /// 512 MiB budget, 8 pooled vectors per size, communication buffers capped
    /// at 2^20 elements, aggressive reuse on, and matrix-size preallocation.
    fn default() -> Self {
        const ONE_MEBI: usize = 1 << 20;
        Self {
            preallocation_strategy: PreallocationStrategy::MatrixSize,
            aggressive_reuse: true,
            max_comm_buffer_size: ONE_MEBI,
            max_vectors_per_size: 8,
            memory_limit_mb: 512,
        }
    }
}
impl SuperLuDistWorkspace {
/// Creates a workspace with `WorkspaceConfig::default()` settings.
pub fn new() -> Self {
    let config = WorkspaceConfig::default();
    Self::with_config(config)
}
/// Creates an empty workspace. `config.memory_limit_mb` is split evenly
/// between the scratch-vector pool and the communication buffers.
pub fn with_config(config: WorkspaceConfig) -> Self {
    let half_budget_mb = config.memory_limit_mb / 2;
    Self {
        memory_pool: MemoryPool::new(config.max_vectors_per_size, half_budget_mb),
        comm_buffers: CommBufferManager::new(config.max_comm_buffer_size, half_budget_mb),
        temp_vectors: HashMap::new(),
        vector_sizes: HashMap::new(),
        config,
    }
}
pub fn setup_for_problem(
&mut self,
matrix_size: usize,
process_grid: &ProcessGrid,
block_size: usize,
) -> Result<(), KError> {
let (local_start, local_end) =
contiguous_partition(matrix_size, process_grid.my_rank, process_grid.total_procs);
let local_size = local_end.saturating_sub(local_start);
let max_local_size = matrix_size.div_ceil(process_grid.total_procs);
let panel_size = self.config.max_comm_buffer_size.min(block_size * 10);
self.vector_sizes
.insert("solution".to_string(), matrix_size);
self.vector_sizes
.insert("residual".to_string(), matrix_size);
self.vector_sizes
.insert("local_work".to_string(), local_size);
self.vector_sizes
.insert("local_work_max".to_string(), max_local_size);
self.vector_sizes
.insert("panel_work".to_string(), panel_size);
self.vector_sizes
.insert("comm_buffer".to_string(), panel_size);
match self.config.preallocation_strategy {
PreallocationStrategy::None => {
}
PreallocationStrategy::MatrixSize => {
self.preallocate_vector("solution", matrix_size)?;
self.preallocate_vector("residual", matrix_size)?;
}
PreallocationStrategy::ProcessGrid => {
self.preallocate_vector("local_work", local_size)?;
for p in 0..process_grid.total_procs {
self.comm_buffers.get_send_buffer(p, panel_size);
self.comm_buffers.get_recv_buffer(p, panel_size);
}
}
PreallocationStrategy::BlockSize => {
self.preallocate_vector("panel_work", panel_size)?;
self.preallocate_vector("block_work", block_size)?;
}
PreallocationStrategy::Full => {
self.preallocate_vector("solution", matrix_size)?;
self.preallocate_vector("residual", matrix_size)?;
self.preallocate_vector("local_work", local_size)?;
self.preallocate_vector("panel_work", panel_size)?;
self.preallocate_vector("block_work", block_size)?;
for p in 0..process_grid.total_procs {
self.comm_buffers.get_send_buffer(p, panel_size);
self.comm_buffers.get_recv_buffer(p, panel_size);
}
}
}
Ok(())
}
/// Takes a vector of `size` elements from the memory pool and caches it under
/// `name`.
///
/// If a vector was already cached under that name, it is handed back to the
/// pool rather than being dropped, so repeated setups keep recycling memory.
///
/// # Errors
/// Currently never fails.
fn preallocate_vector(&mut self, name: &str, size: usize) -> Result<(), KError> {
    let vector = self.memory_pool.get_f64_vector(size);
    if let Some(previous) = self.temp_vectors.insert(name.to_string(), vector) {
        self.memory_pool.return_f64_vector(previous);
    }
    Ok(())
}
/// Returns the zero-filled scratch vector cached under `name`.
///
/// The required length is `max(size, expected)`, where `expected` is the size
/// recorded for `name` in `vector_sizes` (or `size` when none was recorded).
/// - A missing vector is drawn from the memory pool.
/// - A cached vector that is too short is grown.
/// - One that is more than twice the required length is truncated and its
///   excess capacity released.
/// - Otherwise it is reused as-is, so it may be longer than requested.
///
/// In every case all elements of the returned vector are 0.0.
pub fn get_temp_vector(&mut self, name: &str, size: usize) -> &mut Vec<f64> {
    let expected_size = self.vector_sizes.get(name).copied().unwrap_or(size);
    let actual_size = size.max(expected_size);
    let vector = self
        .temp_vectors
        .entry(name.to_string())
        .or_insert_with(|| self.memory_pool.get_f64_vector(actual_size));
    if vector.len() > actual_size.saturating_mul(2) {
        // Oversized: drop the excess instead of carrying it around.
        vector.truncate(actual_size);
        vector.shrink_to_fit();
    }
    // Zero the existing contents *before* growing. Previously only the
    // "already the right size" path was cleared, so reused elements kept stale
    // values after a grow or shrink.
    vector.fill(0.0);
    if vector.len() < actual_size {
        vector.resize(actual_size, 0.0);
    }
    vector
}
/// Returns to the memory pool every cached temp vector whose name has no
/// recorded expected size in `vector_sizes`.
pub fn cleanup_unused_vectors(&mut self) {
    let known_sizes = &self.vector_sizes;
    let pool = &mut self.memory_pool;
    self.temp_vectors.retain(|name, vector| {
        let keep = known_sizes.contains_key(name);
        if !keep {
            pool.return_f64_vector(std::mem::take(vector));
        }
        keep
    });
}
pub fn return_temp_vector(&mut self, name: &str) {
if self.config.aggressive_reuse
&& let Some(vector) = self.temp_vectors.remove(name)
{
self.memory_pool.return_f64_vector(vector);
}
}
/// Mutable access to this workspace's communication-buffer manager.
pub fn get_comm_buffers(&mut self) -> &mut CommBufferManager {
    let buffers = &mut self.comm_buffers;
    buffers
}
/// Hands every cached temp vector back to the memory pool and clears all
/// communication buffers.
pub fn clear_temp_data(&mut self) {
    let pool = &mut self.memory_pool;
    self.temp_vectors.drain().for_each(|(_, vector)| {
        pool.return_f64_vector(vector);
    });
    self.comm_buffers.clear_buffers();
}
/// Aggregates memory usage of the cached temp vectors, the memory pool and the
/// communication buffers.
///
/// Temp-vector memory counts capacity (not length) times `size_of::<f64>()`.
pub fn memory_stats(&self) -> WorkspaceMemoryStats {
    let bytes_per_element = std::mem::size_of::<f64>();
    let mut temp_vectors_memory = 0usize;
    for vector in self.temp_vectors.values() {
        temp_vectors_memory += vector.capacity() * bytes_per_element;
    }
    let pool_stats = self.memory_pool.memory_stats();
    let comm_stats = self.comm_buffers.memory_stats();
    let pool_memory = pool_stats.total_memory_bytes;
    let comm_memory = comm_stats.total_memory_bytes;
    WorkspaceMemoryStats {
        temp_vectors_memory,
        pool_memory,
        comm_memory,
        total_memory: temp_vectors_memory + pool_memory + comm_memory,
        temp_vectors_count: self.temp_vectors.len(),
        pool_stats,
        comm_stats,
    }
}
/// Returns `true` when the workspace's total tracked memory exceeds
/// `config.memory_limit_mb` mebibytes.
pub fn needs_cleanup(&self) -> bool {
    // Saturate so a very large limit (e.g. `usize::MAX` used as "unlimited")
    // cannot overflow the byte conversion and panic in debug builds.
    let limit_bytes = self.config.memory_limit_mb.saturating_mul(1024 * 1024);
    self.memory_stats().total_memory > limit_bytes
}
/// Releases cached memory.
///
/// - With `aggressive_reuse` enabled, every temp vector is moved back into the
///   memory pool.
/// - Otherwise the memory pool is cleared and the temp vectors are kept.
///
/// NOTE(review): in the aggressive branch the vectors only change owner, so
/// whether `memory_stats().total_memory` drops depends on how the pool treats
/// returned vectors; confirm this is the intended response to `needs_cleanup`.
pub fn cleanup(&mut self) {
    if !self.config.aggressive_reuse {
        self.memory_pool.clear();
        return;
    }
    let pool = &mut self.memory_pool;
    for (_, vector) in self.temp_vectors.drain() {
        pool.return_f64_vector(vector);
    }
}
/// Drops the size hints in `vector_sizes` that do not match the length of any
/// cached temp vector, then asks every temp vector to release excess capacity.
pub fn optimize(&mut self) {
    // Compare against lengths, not capacities. `vector_sizes` stores element
    // counts, while a vector's capacity may legitimately exceed its length
    // (pool reuse, growth policy), which used to drop hints for live vectors.
    let live_lengths: std::collections::HashSet<usize> =
        self.temp_vectors.values().map(Vec::len).collect();
    self.vector_sizes
        .retain(|_, &mut size| live_lengths.contains(&size));
    for vector in self.temp_vectors.values_mut() {
        vector.shrink_to_fit();
    }
}
}
/// Snapshot of a workspace's memory usage, as produced by
/// `SuperLuDistWorkspace::memory_stats`. The `*_memory` fields are byte counts.
#[derive(Debug, Clone)]
pub struct WorkspaceMemoryStats {
/// Bytes reserved by cached temp vectors (capacity times `size_of::<f64>()`).
pub temp_vectors_memory: usize,
/// Copy of `pool_stats.total_memory_bytes`.
pub pool_memory: usize,
/// Copy of `comm_stats.total_memory_bytes`.
pub comm_memory: usize,
/// Sum of `temp_vectors_memory`, `pool_memory` and `comm_memory`.
pub total_memory: usize,
/// Number of cached temp vectors.
pub temp_vectors_count: usize,
/// Full statistics reported by the memory pool.
pub pool_stats: MemoryStats,
/// Full statistics reported by the communication-buffer manager.
pub comm_stats: MemoryStats,
}
impl Default for SuperLuDistWorkspace {
fn default() -> Self {
Self::new()
}
}
/// Direct LU solver for `CsrMatrix<f64>`.
///
/// Setup builds a process grid and distribution and factors the matrix in
/// column panels. Solves run distributed forward/backward triangular solves,
/// optionally followed by iterative refinement.
pub struct SuperLuDistSolver {
/// Factorization and solve options.
options: SuperLuDistOptions,
/// Grid, distribution, factors and solve workspace; `None` until setup succeeds.
data: Option<SuperLuDistData>,
/// Refinement engine used after the triangular solves, if configured.
refinement_engine: Option<RefinementEngine>,
/// Configuration used to build the solve workspace during setup.
workspace_config: WorkspaceConfig,
}
/// Fluent builder for [`SuperLuDistSolver`]; finish with `build()`.
pub struct SuperLuDistBuilder {
/// Solver options accumulated so far.
options: SuperLuDistOptions,
/// Workspace configuration accumulated so far.
workspace_config: WorkspaceConfig,
/// Refinement configuration; used by `build()` to create a refinement engine.
refinement_config: Option<RefinementConfig>,
/// Residual method for the refinement engine (`Standard` when unset).
residual_method: Option<ResidualMethod>,
}
impl SuperLuDistBuilder {
    /// Starts from default options and workspace configuration, with no
    /// refinement settings.
    pub fn new() -> Self {
        Self {
            options: SuperLuDistOptions::default(),
            workspace_config: WorkspaceConfig::default(),
            refinement_config: None,
            residual_method: None,
        }
    }

    /// Sets the diagonal pivot threshold, clamped to `[0.0, 1.0]`.
    pub fn diagonal_pivot_threshold(mut self, threshold: f64) -> Self {
        self.options.diagonal_pivot_threshold = threshold.clamp(0.0, 1.0);
        self
    }

    /// Sets the column ordering used during symbolic factorization.
    pub fn column_permutation(mut self, perm: ColumnPermutation) -> Self {
        self.options.column_permutation = perm;
        self
    }

    /// Sets the row permutation used during symbolic factorization.
    pub fn row_permutation(mut self, perm: RowPermutation) -> Self {
        self.options.row_permutation = perm;
        self
    }

    /// Sets the iterative-refinement mode (`NoRefine` disables refinement at solve time).
    pub fn iterative_refinement(mut self, refine: IterativeRefinement) -> Self {
        self.options.iterative_refinement = refine;
        self
    }

    /// Sets `options.print_level`.
    pub fn print_level(mut self, level: u8) -> Self {
        self.options.print_level = level;
        self
    }

    /// Enables/disables tiny-pivot replacement (selects `ThresholdWithFallback`
    /// pivoting unless static pivoting is on).
    pub fn replace_tiny_pivots(mut self, enable: bool) -> Self {
        self.options.replace_tiny_pivots = enable;
        self
    }

    /// Enables/disables static pivoting (takes precedence over tiny-pivot replacement).
    pub fn static_pivoting(mut self, enable: bool) -> Self {
        self.options.static_pivoting = enable;
        self
    }

    /// Requests an explicit `rows x cols` process grid.
    pub fn process_grid(mut self, rows: usize, cols: usize) -> Self {
        self.options.process_grid = Some((rows, cols));
        self
    }

    /// Lets setup choose the process grid automatically.
    pub fn process_grid_auto(mut self) -> Self {
        self.options.process_grid = None;
        self
    }

    /// Sets the factorization panel width (default 64 when unset).
    pub fn panel_size(mut self, size: usize) -> Self {
        self.options.panel_size = Some(size);
        self
    }

    /// Configures 3D factorization.
    ///
    /// Enabling requires `depth >= 2`. A missing or smaller depth falls back to
    /// 2D (3D disabled, depth cleared), with a warning under the `logging`
    /// feature. Disabling stores `depth` unchanged.
    pub fn enable_3d_factorization(mut self, enable: bool, depth: Option<usize>) -> Self {
        if enable {
            let d = depth.unwrap_or(0);
            if d < 2 {
                #[cfg(feature = "logging")]
                log::warn!(
                    "process_grid_3d_depth={d} is too small, falling back to 2D factorization"
                );
                self.options.enable_3d_factorization = false;
                self.options.process_grid_3d_depth = None;
            } else {
                self.options.enable_3d_factorization = true;
                self.options.process_grid_3d_depth = Some(d);
            }
        } else {
            self.options.enable_3d_factorization = false;
            self.options.process_grid_3d_depth = depth;
        }
        self
    }

    /// Sets the memory trade-off factor (minimum 0.1).
    pub fn memory_tradeoff_factor(mut self, factor: f64) -> Self {
        self.options.memory_tradeoff_factor = factor.max(0.1);
        self
    }

    /// Sets the maximum number of concurrent panels (minimum 1).
    pub fn max_concurrent_panels(mut self, max_panels: usize) -> Self {
        self.options.max_concurrent_panels = max_panels.max(1);
        self
    }

    /// Enables/disables asynchronous panel updates.
    pub fn async_panel_updates(mut self, enable: bool) -> Self {
        self.options.async_panel_updates = enable;
        self
    }

    /// Sets the workspace memory limit in MiB.
    pub fn workspace_memory_limit(mut self, limit_mb: usize) -> Self {
        self.workspace_config.memory_limit_mb = limit_mb;
        self
    }

    /// Enables/disables aggressive reuse of workspace vectors.
    pub fn aggressive_memory_reuse(mut self, enable: bool) -> Self {
        self.workspace_config.aggressive_reuse = enable;
        self
    }

    /// Selects which workspace vectors are allocated eagerly at setup.
    pub fn preallocation_strategy(mut self, strategy: PreallocationStrategy) -> Self {
        self.workspace_config.preallocation_strategy = strategy;
        self
    }

    /// Sets the refinement configuration; `build()` then creates a refinement engine.
    pub fn refinement_config(mut self, config: RefinementConfig) -> Self {
        self.refinement_config = Some(config);
        self
    }

    /// Sets the residual method; `build()` then creates a refinement engine.
    pub fn residual_method(mut self, method: ResidualMethod) -> Self {
        self.residual_method = Some(method);
        self
    }

    /// Consumes the builder and produces a configured [`SuperLuDistSolver`].
    ///
    /// The options are validated without a communicator. A failure is not fatal
    /// here (it is only logged under the `logging` feature) because `setup`
    /// validates again against the real communicator and returns the error.
    ///
    /// A refinement engine is created when either `refinement_config` or
    /// `residual_method` was set. Missing pieces default to
    /// `RefinementConfig::default()` and `ResidualMethod::Standard`, matching
    /// `SuperLuDistSolver::set_refinement_config` / `set_residual_method`.
    /// Previously a lone `residual_method` was silently dropped.
    pub fn build(self) -> SuperLuDistSolver {
        let mut solver = SuperLuDistSolver {
            options: self.options,
            data: None,
            refinement_engine: None,
            workspace_config: self.workspace_config,
        };
        if solver.options.validate(None).is_err() {
            #[cfg(feature = "logging")]
            log::warn!(
                "SuperLuDistBuilder::build: options failed validation; setup() will report the error"
            );
        }
        solver.refinement_engine = match (self.refinement_config, self.residual_method) {
            (None, None) => None,
            (config, method) => Some(RefinementEngine::new(
                config.unwrap_or_default(),
                method.unwrap_or(ResidualMethod::Standard),
            )),
        };
        solver
    }
}
impl Default for SuperLuDistBuilder {
fn default() -> Self {
Self::new()
}
}
impl SuperLuDistSolver {
pub fn new() -> Self {
Self {
options: SuperLuDistOptions::default(),
data: None,
refinement_engine: None,
workspace_config: WorkspaceConfig::default(),
}
}
/// Creates a solver using `options`, the default workspace configuration and
/// no refinement engine.
pub fn with_options(options: SuperLuDistOptions) -> Self {
    let workspace_config = WorkspaceConfig::default();
    Self {
        options,
        workspace_config,
        data: None,
        refinement_engine: None,
    }
}
/// Sets the diagonal pivot threshold, clamped to `[0.0, 1.0]`.
pub fn set_diagonal_pivot_threshold(&mut self, threshold: f64) -> &mut Self {
    let clamped = threshold.clamp(0.0, 1.0);
    self.options.diagonal_pivot_threshold = clamped;
    self
}
/// Sets the column ordering used during symbolic factorization.
pub fn set_column_permutation(&mut self, perm: ColumnPermutation) -> &mut Self {
    let options = &mut self.options;
    options.column_permutation = perm;
    self
}
/// Sets the iterative-refinement mode; `NoRefine` skips refinement at solve time.
pub fn set_iterative_refinement(&mut self, refine: IterativeRefinement) -> &mut Self {
    let options = &mut self.options;
    options.iterative_refinement = refine;
    self
}
/// Sets `options.print_level`.
pub fn set_print_level(&mut self, level: u8) -> &mut Self {
    let options = &mut self.options;
    options.print_level = level;
    self
}
/// Sets the factorization panel width (64 is used when never set).
pub fn set_panel_size(&mut self, size: usize) -> &mut Self {
    let width = Some(size);
    self.options.panel_size = width;
    self
}
/// Configures 3D factorization.
///
/// Enabling requires `depth >= 2`. A missing or smaller depth falls back to 2D
/// (3D disabled, depth cleared), with a warning under the `logging` feature.
/// Disabling stores `depth` as given.
pub fn set_3d_factorization(&mut self, enable: bool, depth: Option<usize>) -> &mut Self {
    let (enabled, stored_depth) = match (enable, depth.unwrap_or(0)) {
        (false, _) => (false, depth),
        (true, d) if d < 2 => {
            #[cfg(feature = "logging")]
            log::warn!(
                "process_grid_3d_depth={d} is too small, falling back to 2D factorization"
            );
            (false, None)
        }
        (true, d) => (true, Some(d)),
    };
    self.options.enable_3d_factorization = enabled;
    self.options.process_grid_3d_depth = stored_depth;
    self
}
/// Sets the memory trade-off factor (values below 0.1 are raised to 0.1).
pub fn set_memory_tradeoff(&mut self, factor: f64) -> &mut Self {
    self.options.memory_tradeoff_factor = f64::max(factor, 0.1);
    self
}
/// Sets the maximum number of concurrent panels (at least 1).
pub fn set_max_concurrent_panels(&mut self, max_panels: usize) -> &mut Self {
    self.options.max_concurrent_panels = usize::max(max_panels, 1);
    self
}
/// Enables or disables asynchronous panel updates.
pub fn set_async_panel_updates(&mut self, enable: bool) -> &mut Self {
    let options = &mut self.options;
    options.async_panel_updates = enable;
    self
}
/// Enables or disables static pivoting; when on it overrides tiny-pivot
/// replacement in the numeric factorization.
pub fn set_static_pivoting(&mut self, enable: bool) -> &mut Self {
    let options = &mut self.options;
    options.static_pivoting = enable;
    self
}
/// Sets the row permutation used during symbolic factorization.
pub fn set_row_permutation(&mut self, perm: RowPermutation) -> &mut Self {
    let options = &mut self.options;
    options.row_permutation = perm;
    self
}
/// Enables or disables tiny-pivot replacement (selects `ThresholdWithFallback`
/// pivoting unless static pivoting is on).
pub fn set_replace_tiny_pivots(&mut self, enable: bool) -> &mut Self {
    let options = &mut self.options;
    options.replace_tiny_pivots = enable;
    self
}
/// Requests an explicit `rows x cols` process grid for the next setup.
pub fn set_process_grid(&mut self, rows: usize, cols: usize) -> &mut Self {
    let dims = (rows, cols);
    self.options.process_grid = Some(dims);
    self
}
/// Lets the next setup choose the process grid automatically.
pub fn set_process_grid_auto(&mut self) -> &mut Self {
    let options = &mut self.options;
    options.process_grid = None;
    self
}
/// Replaces the whole option set, consuming and returning the solver.
pub fn with_complete_options(mut self, options: SuperLuDistOptions) -> Self {
    let _previous = std::mem::replace(&mut self.options, options);
    self
}
/// Starts a [`SuperLuDistBuilder`] with default settings.
pub fn builder() -> SuperLuDistBuilder {
    SuperLuDistBuilder::default()
}
/// Read-only view of the current options.
pub fn options(&self) -> &SuperLuDistOptions {
    let current = &self.options;
    current
}
/// Installs a refinement engine with default settings, replacing any existing one.
pub fn enable_iterative_refinement(&mut self) -> &mut Self {
    let engine = RefinementEngine::with_defaults();
    self.refinement_engine = Some(engine);
    self
}
pub fn set_refinement_config(&mut self, config: RefinementConfig) -> &mut Self {
if let Some(ref mut engine) = self.refinement_engine {
engine.set_config(config);
} else {
self.refinement_engine = Some(RefinementEngine::new(config, ResidualMethod::Standard));
}
self
}
pub fn set_residual_method(&mut self, method: ResidualMethod) -> &mut Self {
if let Some(ref mut engine) = self.refinement_engine {
engine.set_residual_method(method);
} else {
self.refinement_engine =
Some(RefinementEngine::new(RefinementConfig::default(), method));
}
self
}
/// Removes the refinement engine, so no refinement runs after solves.
pub fn disable_iterative_refinement(&mut self) -> &mut Self {
    let _dropped = self.refinement_engine.take();
    self
}
/// Statistics from the most recent refinement run.
///
/// Returns `None` when no engine is installed or the engine has none yet.
pub fn refinement_stats(&self) -> Option<&RefinementStats> {
    let engine = self.refinement_engine.as_ref()?;
    engine.last_stats()
}
/// Sets the workspace memory limit (MiB) used for workspaces built at setup.
pub fn set_workspace_memory_limit(&mut self, limit_mb: usize) -> &mut Self {
    let config = &mut self.workspace_config;
    config.memory_limit_mb = limit_mb;
    self
}
/// Enables or disables aggressive reuse of workspace vectors built at setup.
pub fn set_aggressive_memory_reuse(&mut self, enable: bool) -> &mut Self {
    let config = &mut self.workspace_config;
    config.aggressive_reuse = enable;
    self
}
/// Selects which workspace vectors are allocated eagerly at setup.
pub fn set_preallocation_strategy(&mut self, strategy: PreallocationStrategy) -> &mut Self {
    let config = &mut self.workspace_config;
    config.preallocation_strategy = strategy;
    self
}
/// Memory statistics of the solve workspace.
///
/// Returns `None` before setup, or if no solve workspace exists.
pub fn workspace_memory_stats(&self) -> Option<WorkspaceMemoryStats> {
    let data = self.data.as_ref()?;
    let solve_workspace = data.solve_workspace.as_ref()?;
    Some(solve_workspace.workspace.memory_stats())
}
pub fn optimize_workspace(&mut self) -> Result<(), KError> {
if let Some(ref mut data) = self.data
&& let Some(ref mut solve_workspace) = data.solve_workspace
{
solve_workspace.workspace.optimize();
}
Ok(())
}
pub fn clear_workspace_temp_data(&mut self) -> Result<(), KError> {
if let Some(ref mut data) = self.data
&& let Some(ref mut solve_workspace) = data.solve_workspace
{
solve_workspace.workspace.clear_temp_data();
}
Ok(())
}
/// Whether the solve workspace is over its memory limit.
///
/// Returns `false` when there is no workspace.
pub fn workspace_needs_cleanup(&self) -> bool {
    self.data
        .as_ref()
        .and_then(|data| data.solve_workspace.as_ref())
        .is_some_and(|solve_workspace| solve_workspace.workspace.needs_cleanup())
}
/// Builds all solver state for `matrix` and stores it in `self.data`.
///
/// Steps:
/// 1. Validate the options against `comm`.
/// 2. Build the process grid (explicit dims, or automatic).
/// 3. Create a block-cyclic distribution with block sizes
///    `ceil(rows / prows)` x `ceil(cols / pcols)` (each at least 1).
/// 4. Keep a copy of the whole matrix.
/// 5. Run symbolic factorization, numeric factorization and solve-workspace
///    setup, then mark the data factored.
///
/// `self.data` is only replaced when every step succeeds.
///
/// # Errors
/// Propagates errors from validation, grid creation, either factorization
/// phase, or workspace setup.
fn setup_factorization(
    &mut self,
    matrix: &CsrMatrix<f64>,
    comm: &UniverseComm,
) -> Result<(), KError> {
    #[cfg(feature = "logging")]
    let _guard = StageGuard::new("SuperLuDistSetup");
    self.options.validate(Some(comm))?;
    let process_grid = match self.options.process_grid {
        Some((prows, pcols)) => ProcessGrid::new_with_dims(comm, prows, pcols)?,
        None => ProcessGrid::new_auto(comm)?,
    };
    let row_block_size = matrix.nrows().div_ceil(process_grid.prows).max(1);
    let col_block_size = matrix.ncols().div_ceil(process_grid.pcols).max(1);
    let distribution = BlockCyclicDistribution::new(
        process_grid.clone(),
        matrix.nrows(),
        matrix.ncols(),
        row_block_size,
        col_block_size,
    );
    // The full (global) matrix is stored as the "local" matrix; it is not
    // split with `distribute_matrix` here.
    let mut slu_data = SuperLuDistData {
        process_grid,
        distribution,
        comm: comm.clone(),
        options: self.options.clone(),
        factored: false,
        local_matrix: Some(matrix.clone()),
        symbolic_factor: None,
        numeric_factor: None,
        solve_workspace: None,
    };
    // Each phase gets its own scoped profiling stage. Previously the guards
    // lived until the end of the function, so the symbolic stage also timed
    // the numeric factorization and the workspace setup.
    let symbolic = {
        #[cfg(feature = "logging")]
        let _stage = StageGuard::new("SuperLuDistSymbolic");
        self.symbolic_factorization(&slu_data, comm)?
    };
    slu_data.symbolic_factor = Some(symbolic);
    let numeric = {
        #[cfg(feature = "logging")]
        let _stage = StageGuard::new("SuperLuDistNumeric");
        self.numerical_factorization(&slu_data, comm)?
    };
    slu_data.numeric_factor = Some(numeric);
    let workspace = self.setup_solve_workspace(&slu_data)?;
    slu_data.solve_workspace = Some(workspace);
    slu_data.factored = true;
    self.data = Some(slu_data);
    Ok(())
}
/// Orders rows by decreasing magnitude of their diagonal entry (ties by row index).
///
/// Each rank fills the largest |a_ii| for its contiguous row partition. The
/// vector is then sum-reduced across ranks, so rows owned elsewhere contribute
/// their value and zeros elsewhere. Returns `perm` with `perm[k]` = the row
/// placed at position `k`.
fn large_diag_row_permutation(
    &self,
    matrix: &CsrMatrix<f64>,
    comm: &UniverseComm,
) -> Vec<usize> {
    let n = matrix.nrows();
    let (row_ptr, col_idx, values) = (matrix.row_ptr(), matrix.col_idx(), matrix.values());
    let (first_row, last_row) = contiguous_partition(n, comm.rank(), comm.size());
    let mut diag_abs = vec![0.0f64; n];
    for row in first_row..last_row {
        // Largest |value| stored on the diagonal of this row (0.0 if none).
        diag_abs[row] = (row_ptr[row]..row_ptr[row + 1])
            .filter(|&k| col_idx[k] == row)
            .map(|k| values[k].abs())
            .fold(0.0, f64::max);
    }
    comm.allreduce_sum_slice(&mut diag_abs);
    let mut order: Vec<usize> = (0..n).collect();
    order.sort_by(|&lhs, &rhs| {
        diag_abs[rhs]
            .partial_cmp(&diag_abs[lhs])
            .unwrap_or(Ordering::Equal)
            .then(lhs.cmp(&rhs))
    });
    order
}
/// Computes reciprocal row and column max-norm scaling factors.
///
/// Each rank scans its contiguous row partition:
/// - Row maxima are sum-reduced across ranks. Assuming `contiguous_partition`
///   gives disjoint ranges, this yields each row's max.
/// - Column maxima are max-reduced across ranks.
///
/// Each factor is `1 / max`, or `1.0` for an all-zero row or column.
///
/// The column vector has `max(nrows, ncols)` entries. It is at least `ncols`
/// long so every stored column index is in range: it used to be sized by
/// `nrows` and panicked for matrices with more columns than rows. It is at
/// least `nrows` long so its length is unchanged for square and tall matrices.
///
/// # Errors
/// Propagates failures of the max reduction.
fn compute_scaling_factors(
    &self,
    matrix: &CsrMatrix<f64>,
    comm: &UniverseComm,
) -> Result<(Vec<f64>, Vec<f64>), KError> {
    let n = matrix.nrows();
    let col_len = n.max(matrix.ncols());
    let mut row_max = vec![0.0f64; n];
    let mut col_max = vec![0.0f64; col_len];
    let (row_start, row_end) = contiguous_partition(n, comm.rank(), comm.size());
    let rp = matrix.row_ptr();
    let cj = matrix.col_idx();
    let vv = matrix.values();
    for i in row_start..row_end {
        let mut local_row_max = 0.0;
        for idx in rp[i]..rp[i + 1] {
            let j = cj[idx];
            let val = vv[idx].abs();
            if val > local_row_max {
                local_row_max = val;
            }
            if val > col_max[j] {
                col_max[j] = val;
            }
        }
        row_max[i] = local_row_max;
    }
    comm.allreduce_sum_slice(&mut row_max);
    self.allreduce_max_slice(comm, &mut col_max)?;
    let reciprocal = |m: f64| if m > 0.0 { 1.0 / m } else { 1.0 };
    let row_scale = row_max.into_iter().map(reciprocal).collect();
    let col_scale = col_max.into_iter().map(reciprocal).collect();
    Ok((row_scale, col_scale))
}
/// Element-wise max-reduction of `data` across all ranks, in place.
///
/// No-op for a single rank. With the `mpi` feature and an MPI communicator
/// this calls `MPI_Allreduce` with `MPI_MAX`.
///
/// NOTE(review): for a multi-rank communicator that is not MPI-backed (or
/// without the `mpi` feature) `data` is left unreduced; confirm that case
/// cannot occur.
///
/// # Errors
/// `SolveError` if `data` is too long for an MPI count or if `MPI_Allreduce`
/// returns a non-success code.
fn allreduce_max_slice(
    &self,
    comm: &UniverseComm,
    data: &mut [f64],
) -> Result<(), KError> {
    if comm.size() <= 1 {
        return Ok(());
    }
    #[cfg(feature = "mpi")]
    if let Some(world) = comm.as_mpi() {
        // MPI counts are C ints: refuse oversized buffers instead of truncating
        // the count with an `as` cast.
        let count = i32::try_from(data.len()).map_err(|_| {
            KError::SolveError(format!(
                "MPI_Allreduce (MAX): {} elements exceed the MPI count range",
                data.len()
            ))
        })?;
        let mut recv = vec![0.0; data.len()];
        // SAFETY: `data` and `recv` are distinct, live buffers of exactly
        // `count` f64 values for the duration of the call, matching the
        // MPI_DOUBLE datatype; `world.as_raw()` is the handle of the
        // communicator obtained from `comm`.
        let rc = unsafe {
            mpi::ffi::MPI_Allreduce(
                data.as_ptr() as *const c_void,
                recv.as_mut_ptr() as *mut c_void,
                count,
                mpi::ffi::RSMPI_DOUBLE,
                mpi::ffi::RSMPI_MAX,
                world.as_raw(),
            )
        };
        if rc != 0 {
            return Err(KError::SolveError(format!(
                "MPI_Allreduce (MAX) failed: {rc}"
            )));
        }
        data.copy_from_slice(&recv);
        return Ok(());
    }
    #[cfg(not(feature = "mpi"))]
    let _ = data;
    Ok(())
}
#[allow(dead_code)]
fn distribute_matrix(
&self,
global_matrix: &CsrMatrix<f64>,
distribution: &BlockCyclicDistribution,
) -> Result<CsrMatrix<f64>, KError> {
let rp = global_matrix.row_ptr();
let cj = global_matrix.col_idx();
let vv = global_matrix.values();
let local_rows = distribution.local_rows;
let local_cols = distribution.local_cols;
let mut tmp_cols = vec![Vec::new(); local_rows];
let mut tmp_vals = vec![Vec::new(); local_rows];
for i in 0..global_matrix.nrows() {
for p in rp[i]..rp[i + 1] {
let j = cj[p];
if distribution.owner_of(i, j) == distribution.grid.my_rank {
let li = distribution.local_row_from_global(i).unwrap();
let lj = distribution.local_col_from_global(j).unwrap();
tmp_cols[li].push(lj);
tmp_vals[li].push(vv[p]);
}
}
}
let mut local_row_ptrs = Vec::with_capacity(local_rows + 1);
local_row_ptrs.push(0);
let mut local_col_indices = Vec::new();
let mut local_values = Vec::new();
for r in 0..local_rows {
local_col_indices.extend_from_slice(&tmp_cols[r]);
local_values.extend_from_slice(&tmp_vals[r]);
let last = *local_row_ptrs.last().unwrap();
local_row_ptrs.push(last + tmp_cols[r].len());
}
let local_matrix = CsrMatrix::from_csr(
local_rows,
local_cols,
local_row_ptrs,
local_col_indices,
local_values,
);
validate_local_csr(&local_matrix)?;
Ok(local_matrix)
}
/// Symbolic phase: chooses row and column orderings and computes the L/U
/// nonzero patterns and the elimination tree.
///
/// Column orderings:
/// - `Natural` and `User` use the natural order.
/// - `MmdAta` uses the distributed variant when more than one rank is present.
/// - `Metis` and `ParMetis` use their respective orderings.
///
/// Row orderings:
/// - `NoRowPerm` and `User` use the natural order.
/// - `LargeDiag` sorts rows by diagonal magnitude.
///
/// The U pattern is the transpose of the L pattern's keys.
///
/// # Errors
/// `SolveError` if no matrix is stored; propagates ordering and
/// symbolic-pattern errors.
///
/// # Panics
/// If the stored matrix's dimensions differ from the distribution's.
fn symbolic_factorization(
    &self,
    data: &SuperLuDistData,
    comm: &UniverseComm,
) -> Result<SymbolicFactorization, KError> {
    #[cfg(feature = "logging")]
    let _guard = StageGuard::new("SymbolicFactorization");
    let n = data.distribution.global_rows;
    let Some(matrix) = data.local_matrix.as_ref() else {
        return Err(KError::SolveError("No local/global matrix".to_string()));
    };
    assert_eq!(matrix.nrows(), n);
    assert_eq!(matrix.ncols(), data.distribution.global_cols);
    #[cfg(feature = "logging")]
    log::debug!(
        "Symbolic factorization: n={}, nnz={}, col_perm={:?}",
        n,
        matrix.nnz(),
        self.options.column_permutation
    );
    let col_perm = match self.options.column_permutation {
        ColumnPermutation::Natural | ColumnPermutation::User => {
            OrderingAlgorithms::natural_ordering(n)
        }
        ColumnPermutation::MmdAta if comm.size() > 1 => {
            OrderingAlgorithms::mmd_ata_ordering_distributed(matrix, comm, &data.distribution)?
        }
        ColumnPermutation::MmdAta => OrderingAlgorithms::mmd_ata_ordering(matrix),
        ColumnPermutation::Metis => OrderingAlgorithms::metis_ordering(matrix)?,
        ColumnPermutation::ParMetis => {
            OrderingAlgorithms::parmetis_ordering(matrix, comm, &data.distribution)?
        }
    };
    let row_perm = match self.options.row_permutation {
        RowPermutation::LargeDiag => self.large_diag_row_permutation(matrix, comm),
        RowPermutation::NoRowPerm | RowPermutation::User => {
            OrderingAlgorithms::natural_ordering(n)
        }
    };
    #[cfg(feature = "logging")]
    log::debug!("Computing symbolic pattern with {n} x {n} matrix");
    let l_pattern = SymbolicFactorizer::compute_symbolic_pattern(matrix, &col_perm, &row_perm)?;
    for k in 0..n {
        debug_assert!(l_pattern.contains_key(&(k, k)));
    }
    let u_pattern: HashMap<_, _> = l_pattern
        .keys()
        .map(|&(i, j)| ((j, i), true))
        .collect();
    let etree = SymbolicFactorizer::build_elimination_tree(n, &l_pattern);
    #[cfg(feature = "logging")]
    log::debug!(
        "Symbolic factorization completed: {} L entries, {} U entries",
        l_pattern.len(),
        u_pattern.len()
    );
    Ok(SymbolicFactorization {
        col_perm,
        row_perm,
        etree,
        l_pattern,
        u_pattern,
    })
}
/// Numeric phase: LU-factors the stored matrix panel by panel over column ranges.
///
/// Panel construction:
/// - Ranges are `panel_size` columns wide (`options.panel_size`, default 64,
///   at least 1, at most `n`).
/// - Each panel holds every row with at least one entry in the range, taken
///   from the stored matrix; ranges with no entries are skipped.
///
/// Pivoting strategy, in priority order:
/// - `static_pivoting` selects Static;
/// - otherwise `replace_tiny_pivots` selects ThresholdWithFallback;
/// - otherwise Dynamic.
///
/// Afterwards it computes row/column scaling factors and coarse lower/upper
/// block dependency graphs from the symbolic L/U patterns.
///
/// NOTE(review): the symbolic row/column permutations are copied into the
/// result but are not applied when the panels are built here; confirm that
/// the panel and triangular-solve code account for this.
///
/// # Errors
/// `SolveError` if the stored matrix or symbolic factorization is missing;
/// otherwise propagates panel-factorization and scaling-reduction errors.
fn numerical_factorization(
&self,
data: &SuperLuDistData,
comm: &UniverseComm,
) -> Result<NumericFactorization, KError> {
#[cfg(feature = "logging")]
let _guard = StageGuard::new("NumericalFactorization");
let n = data.distribution.global_rows;
let matrix = data
.local_matrix
.as_ref()
.ok_or_else(|| KError::SolveError("No local matrix available".to_string()))?;
let symbolic = data
.symbolic_factor
.as_ref()
.ok_or_else(|| KError::SolveError("No symbolic factorization available".to_string()))?;
// Static pivoting wins over tiny-pivot replacement; Dynamic is the fallback.
let pivot_strategy = if self.options.static_pivoting {
PivotingStrategy::Static
} else if self.options.replace_tiny_pivots {
PivotingStrategy::ThresholdWithFallback
} else {
PivotingStrategy::Dynamic
};
let configured_panel_size = self.options.panel_size.unwrap_or(64).max(1);
// Never wider than the matrix; `n.max(1)` keeps `step_by` valid when n == 0.
let panel_size = std::cmp::min(configured_panel_size, n.max(1));
#[cfg(feature = "logging")]
log::debug!(
"Starting numerical factorization with panel size {panel_size}, pivot strategy {pivot_strategy:?}"
);
let mut panels = Vec::new();
let mut panel_factors = Vec::new();
let mut total_row_swaps = 0;
let mut tiny_pivots_replaced_total = 0usize;
// Only read by the `logging`-gated summary below.
let mut panels_with_replacements = 0usize;
// NOTE(review): this records the largest |diagonal| of factored panel data
// (floor 1.0), i.e. a pivot magnitude rather than a growth ratio.
let mut max_pivot_growth = 1.0;
for panel_start in (0..n).step_by(panel_size) {
let panel_end = std::cmp::min(panel_start + panel_size, n);
// Rows with at least one stored entry in columns [panel_start, panel_end).
let mut panel_rows = Vec::new();
let rp = matrix.row_ptr();
let cj = matrix.col_idx();
for i in 0..matrix.nrows() {
let row_start = rp[i];
let row_end = rp[i + 1];
for idx in row_start..row_end {
let col = cj[idx];
if col >= panel_start && col < panel_end {
panel_rows.push(i);
break;
}
}
}
if panel_rows.is_empty() {
continue; }
let mut panel = Panel::from_sparse_columns(matrix, panel_start, panel_end, panel_rows);
match panel.factorize_lu(self.options.diagonal_pivot_threshold, pivot_strategy) {
Ok(factor) => {
total_row_swaps += factor.num_row_swaps;
tiny_pivots_replaced_total += factor.tiny_pivots_replaced;
if factor.tiny_pivots_replaced > 0 {
panels_with_replacements += 1;
}
// Panel data is indexed column-major with leading dimension `height`,
// so `i * height + i` is diagonal entry (i, i).
for i in 0..panel.width.min(panel.height) {
let diag_val = panel.data[i * panel.height + i].abs();
if diag_val > max_pivot_growth {
max_pivot_growth = diag_val;
}
}
panel_factors.push(factor);
}
Err(e) => {
#[cfg(feature = "logging")]
log::error!("Panel factorization failed: {e}");
return Err(e);
}
}
panels.push(panel);
}
let global_row_perm = symbolic.row_perm.clone();
let global_col_perm = symbolic.col_perm.clone();
let (row_scale, col_scale) = self.compute_scaling_factors(matrix, comm)?;
// Approximate bytes: panel data plus the two permutation vectors.
let memory_usage = panels
.iter()
.map(|p| p.data.len() * std::mem::size_of::<f64>())
.sum::<usize>()
+ (global_row_perm.len() + global_col_perm.len()) * std::mem::size_of::<usize>();
let factor_stats = FactorizationStats {
num_panels: panels.len(),
total_row_swaps,
tiny_pivots_replaced: tiny_pivots_replaced_total,
max_pivot_growth,
condition_estimate: None, memory_usage,
};
#[cfg(feature = "logging")]
if self.options.enabled(1, 1) {
log::info!(
"Numerical factorization completed: {} panels, {} row swaps, {} panels replaced {} tiny pivots, max pivot growth {:.2e}",
factor_stats.num_panels,
factor_stats.total_row_swaps,
panels_with_replacements,
factor_stats.tiny_pivots_replaced,
factor_stats.max_pivot_growth
);
}
// Coarse block dependency graphs with block size min(64, n / 4), at least 1.
// An edge bi -> bj is recorded for each L entry with bj < bi and each U entry
// with bj > bi. Neighbour lists end up sorted and duplicate-free.
let bs = std::cmp::min(64, n / 4).max(1);
let nb = n.div_ceil(bs);
let mut lbg = vec![Vec::<usize>::new(); nb];
let mut ubg = vec![Vec::<usize>::new(); nb];
let add_edge = |graph: &mut [Vec<usize>], s: usize, t: usize| {
if s != t && !graph[s].contains(&t) {
graph[s].push(t);
}
};
for &(i, j) in symbolic.l_pattern.keys() {
let bi = i / bs;
let bj = j / bs;
if bj < bi {
add_edge(&mut lbg, bi, bj);
}
}
for &(i, j) in symbolic.u_pattern.keys() {
let bi = i / bs;
let bj = j / bs;
if bj > bi {
add_edge(&mut ubg, bi, bj);
}
}
for g in [&mut lbg, &mut ubg] {
for v in g.iter_mut() {
v.sort_unstable();
}
}
Ok(NumericFactorization {
n,
// Counts stored panel values (including explicit zeros), not matrix nnz.
nnz: panels.iter().map(|p| p.data.len()).sum(),
panels,
panel_factors,
global_row_perm,
global_col_perm,
row_scale,
col_scale,
pivot_strategy,
pivot_threshold: self.options.diagonal_pivot_threshold,
replaced_tiny_pivots: tiny_pivots_replaced_total > 0,
factor_stats,
l_block_graph: lbg,
u_block_graph: ubg,
})
}
/// Builds the scratch state used by solves on `data`.
///
/// The inner `SuperLuDistWorkspace` uses this solver's `workspace_config`,
/// with `max_comm_buffer_size` overridden to `max(panel_size * total_procs, 1024)`
/// (panel size defaults to 64), and is sized with `setup_for_problem`.
///
/// It also allocates:
/// - one zeroed vector per rank, holding that rank's local row count;
/// - `solution_temp`, `rhs_temp` and `reduction_temp` (local rows);
/// - `column_accum_temp` (local columns);
/// - `permutation_temp` (global size on rank 0, empty elsewhere).
///
/// # Errors
/// Propagates errors from `setup_for_problem`.
fn setup_solve_workspace(&self, data: &SuperLuDistData) -> Result<SolveWorkspace, KError> {
    let n = data.distribution.global_rows;
    let nprocs = data.process_grid.total_procs;
    let block_size = self.options.panel_size.unwrap_or(64).max(1);

    let mut config = self.workspace_config.clone();
    config.max_comm_buffer_size = (block_size * nprocs).max(1024);
    let mut workspace = SuperLuDistWorkspace::with_config(config);
    workspace.setup_for_problem(n, &data.process_grid, block_size)?;

    let process_vectors: HashMap<_, _> = (0..nprocs)
        .map(|rank| (rank, vec![0.0; data.distribution.local_rows_for_rank(rank)]))
        .collect();

    let local_rows = data.distribution.local_rows;
    let permutation_len = if data.process_grid.my_rank == 0 { n } else { 0 };
    let global_vectors: HashMap<_, _> = [
        ("solution_temp", local_rows),
        ("rhs_temp", local_rows),
        ("column_accum_temp", data.distribution.local_cols),
        ("reduction_temp", local_rows),
        ("permutation_temp", permutation_len),
    ]
    .into_iter()
    .map(|(name, len)| (name.to_string(), vec![0.0; len]))
    .collect();

    Ok(SolveWorkspace {
        workspace,
        process_vectors,
        global_vectors,
    })
}
/// Solves `A x = b` with the stored factors, writing the result into `x`.
///
/// Steps:
/// 1. Gather `b` through the row permutation: `permuted_b[i] = b[row_perm[i]]`;
///    out-of-range indices leave 0.
/// 2. Run the distributed forward solve into `y`, then the backward solve into
///    `x`, using point-to-point communication with overlap enabled when more
///    than one rank is present.
/// 3. Scatter through the column permutation: `result[col_perm[i]] = x[i]`.
/// 4. If `options.iterative_refinement` is not `NoRefine` and a refinement
///    engine is installed, refine `x` against the stored matrix.
///
/// NOTE(review): the factorization builds its panels from the unpermuted
/// matrix, yet the permutations are applied here; confirm the solve routines
/// expect this for non-identity orderings.
///
/// # Errors
/// `SolveError` when no data, not factored, or no numeric factorization is
/// present; also propagates triangular-solve and refinement errors.
///
/// # Panics
/// `x.clone_from_slice(b)` panics if `x.len() != b.len()`.
fn solve_factored(
&mut self,
b: &Vec<f64>,
x: &mut Vec<f64>,
comm: &UniverseComm,
) -> Result<(), KError> {
let data = self
.data
.as_ref()
.ok_or_else(|| KError::SolveError("SuperLU_DIST not factored".to_string()))?;
if !data.factored {
return Err(KError::SolveError("Matrix not factored".to_string()));
}
#[cfg(feature = "logging")]
let _guard = StageGuard::new("SuperLuDistSolve");
// `x` starts as a copy of `b`; the backward solve below writes the result into it.
x.clone_from_slice(b);
let numeric_factor = data
.numeric_factor
.as_ref()
.ok_or_else(|| KError::SolveError("No numeric factorization available".to_string()))?;
let comm_pattern = CommPattern::PointToPoint;
let overlap_comm = comm.size() > 1;
#[cfg(feature = "logging")]
if self.options.enabled(1, 1) {
log::info!(
"Starting distributed triangular solve with pattern {comm_pattern:?}, overlap_comm={overlap_comm}"
);
}
// Gather: permuted_b[i] = b[row_perm[i]].
let row_perm = &numeric_factor.global_row_perm;
let mut permuted_b = vec![0.0; b.len()];
for (i, &perm_i) in row_perm.iter().enumerate() {
if perm_i < b.len() {
permuted_b[i] = b[perm_i];
}
}
// A 3D grid is only used when enabled and a depth is set; a failed
// construction silently falls back to `None` (2D).
#[cfg(feature = "superlu3d")]
let grid3d = if self.options.enable_3d_factorization {
self.options
.process_grid_3d_depth
.and_then(|d| ProcessGrid3D::from_2d_with_depth(&data.process_grid, d).ok())
} else {
None
};
let mut y = vec![0.0; x.len()];
DistributedTriangularSolver::forward_solve(
&permuted_b,
&mut y,
numeric_factor,
&data.distribution,
comm,
comm_pattern,
overlap_comm,
#[cfg(feature = "superlu3d")]
grid3d.as_ref(),
)?;
DistributedTriangularSolver::backward_solve(
&y,
x,
numeric_factor,
&data.distribution,
comm,
comm_pattern,
overlap_comm,
#[cfg(feature = "superlu3d")]
grid3d.as_ref(),
)?;
// Scatter: result[col_perm[i]] = x[i].
let col_perm = &numeric_factor.global_col_perm;
let mut permuted_x = vec![0.0; x.len()];
for (i, &perm_i) in col_perm.iter().enumerate() {
if i < x.len() && perm_i < permuted_x.len() {
permuted_x[perm_i] = x[i];
}
}
x.copy_from_slice(&permuted_x);
// `self.refinement_engine` is borrowed mutably while `self.data` is
// re-borrowed immutably (disjoint fields).
if !matches!(
self.options.iterative_refinement,
IterativeRefinement::NoRefine
) && let Some(ref mut engine) = self.refinement_engine
{
let data = self.data.as_ref().unwrap();
let local_matrix = data.local_matrix.as_ref().ok_or_else(|| {
KError::SolveError("Local matrix not available for refinement".to_string())
})?;
let _refinement_stats = engine.refine_solution(local_matrix, b, x, data, comm)?;
#[cfg(feature = "logging")]
if self.options.enabled(1, 1)
&& let Some(stats) = engine.last_stats()
{
log::info!(
"Iterative refinement completed: {} iterations, final residual: {:.2e}",
stats.iterations,
stats.final_residual_norm
);
}
}
#[cfg(feature = "logging")]
if self.options.enabled(1, 1) {
log::info!("Distributed triangular solve completed successfully");
}
Ok(())
}
/// Releases all factorization state and the refinement engine.
///
/// Options and workspace configuration are kept.
pub fn destroy(&mut self) {
    drop(self.data.take());
    drop(self.refinement_engine.take());
}
/// Drops the numeric factorization and marks the solver unfactored.
///
/// The grid, distribution, symbolic analysis and workspace are kept, so a
/// later `factor` can refactor.
pub fn clear_factors(&mut self) {
    let Some(state) = self.data.as_mut() else {
        return;
    };
    state.numeric_factor = None;
    state.factored = false;
}
/// `true` when setup has run and the data is currently marked factored.
pub fn has_factors(&self) -> bool {
    self.data.as_ref().is_some_and(|state| state.factored)
}
}
impl Default for SuperLuDistSolver {
fn default() -> Self {
Self::new()
}
}
impl Solver<CsrMatrix<f64>> for SuperLuDistSolver {
    type Error = KError;

    /// Validates the options and, on first use, builds the grid, distribution,
    /// symbolic and numeric factors and the solve workspace for `a`.
    ///
    /// Later calls only re-validate the options. Use `factor` to refactor a
    /// matrix with the same dimensions.
    fn setup(&mut self, a: &CsrMatrix<f64>, comm: &UniverseComm) -> Result<(), Self::Error> {
        self.options.validate(Some(comm))?;
        if self.data.is_none() {
            self.setup_factorization(a, comm)?;
        }
        Ok(())
    }

    /// Recomputes the numeric factorization of `a`.
    ///
    /// Reuses the grid, distribution and symbolic analysis (orderings and fill
    /// pattern) from `setup`.
    ///
    /// # Errors
    /// - `SolveError` if `setup` has not run.
    /// - `InvalidInput` if `a`'s dimensions differ from those seen at setup.
    /// - Any numeric-factorization error; the solver is then left unfactored.
    fn factor(&mut self, a: &CsrMatrix<f64>) -> Result<(), Self::Error> {
        let data = self.data.as_mut().ok_or_else(|| {
            KError::SolveError("call setup(&A, &comm) before factor(&A)".into())
        })?;
        if a.nrows() != data.distribution.global_rows
            || a.ncols() != data.distribution.global_cols
        {
            return Err(KError::InvalidInput(
                "factor(): matrix dims changed since setup".into(),
            ));
        }
        // The numeric factorization reads `data.local_matrix`, so store the
        // matrix passed here. Previously the copy captured at setup was
        // refactored and the values of `a` were ignored.
        data.local_matrix = Some(a.clone());
        data.numeric_factor = None;
        data.factored = false;
        let numeric = {
            let data_ref = self.data.as_ref().ok_or_else(|| {
                KError::SolveError("call setup(&A, &comm) before factor(&A)".into())
            })?;
            self.numerical_factorization(data_ref, &data_ref.comm)?
        };
        if let Some(data_mut) = self.data.as_mut() {
            data_mut.numeric_factor = Some(numeric);
            data_mut.factored = true;
        }
        Ok(())
    }

    /// Solves with the current factors, writing the solution into `x`.
    ///
    /// Returns placeholder statistics (1 iteration, residual 0.0,
    /// `ConvergedAtol`); no residual is computed here.
    ///
    /// # Errors
    /// - `SolveError` before setup, or when not factored.
    /// - `InvalidInput` when `b.len()` or `x.len()` do not match the matrix
    ///   dimensions.
    fn solve(
        &mut self,
        b: &[f64],
        x: &mut [f64],
        comm: &UniverseComm,
    ) -> Result<SolveStats<f64>, Self::Error> {
        let Some(data) = self.data.as_ref() else {
            return Err(KError::SolveError(
                "solve() called before setup()/factor()".into(),
            ));
        };
        if b.len() != data.distribution.global_rows {
            return Err(KError::InvalidInput("RHS size mismatch".into()));
        }
        if x.len() != data.distribution.global_cols {
            return Err(KError::InvalidInput("solution size mismatch".into()));
        }
        let mut xb = x.to_vec();
        self.solve_factored(&b.to_vec(), &mut xb, comm)?;
        x.copy_from_slice(&xb);
        Ok(SolveStats::new(1, 0.0, ConvergedReason::ConvergedAtol))
    }

    /// Factors can be reused while the solver is marked factored.
    fn reuse_factorization(&self) -> bool {
        self.has_factors()
    }
}
impl LinearSolver<CsrMatrix<f64>, Vec<f64>> for SuperLuDistSolver {
    type Error = KError;
    type Scalar = f64;

    /// One-shot solve of `a x = b` through the legacy interface.
    ///
    /// `pc`, `pc_side`, `monitors` and `_work` are not used by this solver.
    ///
    /// # Errors
    /// Propagates errors from setup, factorization and the solve.
    fn solve(
        &mut self,
        a: &CsrMatrix<f64>,
        pc: Option<
            &(dyn crate::preconditioner::legacy::Preconditioner<CsrMatrix<f64>, Vec<f64>> + '_),
        >,
        b: &Vec<f64>,
        x: &mut Vec<f64>,
        pc_side: crate::preconditioner::PcSide,
        comm: &crate::parallel::UniverseComm,
        monitors: Option<&[Box<MonitorCallback<Self::Scalar>>]>,
        _work: Option<&mut crate::context::ksp_context::Workspace>,
    ) -> Result<SolveStats<f64>, KError> {
        let _ = (pc, pc_side, monitors);
        // `setup` already factors `a` when it builds fresh solver data; only an
        // initialized solver needs an explicit `factor` call. Calling both
        // unconditionally factorized the matrix twice on the first solve.
        let needs_refactor = self.data.is_some();
        self.setup(a, comm)?;
        if needs_refactor {
            self.factor(a)?;
        }
        <SuperLuDistSolver as Solver<CsrMatrix<f64>>>::solve(
            self,
            b.as_slice(),
            x.as_mut_slice(),
            comm,
        )
    }
}
#[cfg(feature = "superlu_dist")]
/// Convenience one-shot solve of `a x = b` with a default `SuperLuDistSolver`.
///
/// `x` supplies the initial contents and receives the solution.
///
/// # Errors
/// Propagates any setup, factorization or solve error.
pub fn solve(
    a: &CsrMatrix<f64>,
    b: &[f64],
    x: &mut [f64],
    comm: &UniverseComm,
) -> Result<(), KError> {
    let mut solver = SuperLuDistSolver::new();
    let rhs = b.to_vec();
    let mut solution = x.to_vec();
    let _ = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        a,
        None,
        &rhs,
        &mut solution,
        crate::preconditioner::PcSide::Left,
        comm,
        None,
        None,
    )?;
    x.copy_from_slice(&solution);
    Ok(())
}
#[cfg(not(feature = "superlu_dist"))]
/// Stub used when the `superlu_dist` feature is disabled.
///
/// # Errors
/// Always returns `SolveError`.
pub fn solve(
    _a: &crate::matrix::sparse::CsrMatrix<f64>,
    _b: &[f64],
    _x: &mut [f64],
    _comm: &UniverseComm,
) -> Result<(), KError> {
    let message = "superlu_dist feature not enabled";
    Err(KError::SolveError(message.into()))
}
#[cfg(test)]
mod tests {
use super::*;
use crate::parallel::NoComm;
use faer::MatMut;
use faer::linalg::solvers::{FullPivLu, SolveCore};
#[test]
fn test_superlu_dist_creation() {
    // A fresh solver has no data and default options.
    let fresh = SuperLuDistSolver::new();
    let opts = &fresh.options;
    assert!(fresh.data.is_none());
    assert_eq!(opts.print_level, 0);
    assert_eq!(opts.diagonal_pivot_threshold, 1.0);
}
#[test]
fn test_superlu_dist_options() {
    // Each setter must be reflected in the stored options.
    let mut solver = SuperLuDistSolver::new();
    solver.set_diagonal_pivot_threshold(0.5);
    solver.set_column_permutation(ColumnPermutation::Metis);
    solver.set_iterative_refinement(IterativeRefinement::Single);
    solver.set_print_level(1);
    let opts = &solver.options;
    assert_eq!(opts.diagonal_pivot_threshold, 0.5);
    assert_eq!(opts.column_permutation, ColumnPermutation::Metis);
    assert_eq!(opts.iterative_refinement, IterativeRefinement::Single);
    assert_eq!(opts.print_level, 1);
}
#[test]
fn test_process_grid_determination() {
    // (process count, expected rows x cols)
    let cases = [
        (1, (1, 1)),
        (4, (2, 2)),
        (6, (2, 3)),
        (8, (2, 4)),
        (16, (4, 4)),
    ];
    for (nprocs, expected) in cases {
        assert_eq!(ProcessGrid::determine_optimal_grid(nprocs), expected);
    }
}
#[test]
fn test_process_grid_creation() {
    // Without a communicator the grid is a single 1x1 process.
    let comm = UniverseComm::NoComm(NoComm);
    let grid = ProcessGrid::new_auto(&comm).unwrap();
    assert_eq!((grid.prows, grid.pcols), (1, 1));
    assert_eq!((grid.my_prow, grid.my_pcol), (0, 0));
    assert_eq!((grid.my_rank, grid.total_procs), (0, 1));
}
#[test]
fn test_block_cyclic_distribution() {
    let comm = UniverseComm::NoComm(NoComm);
    let grid = ProcessGrid::new_auto(&comm).unwrap();
    let (rows, cols, block) = (10, 10, 4);
    let distribution = BlockCyclicDistribution::new(grid, rows, cols, block, block);
    // A single process owns the whole matrix.
    assert_eq!((distribution.global_rows, distribution.global_cols), (rows, cols));
    assert_eq!((distribution.local_rows, distribution.local_cols), (rows, cols));
}
/// With a single process, global row and column indices map to identical local indices.
#[test]
fn test_global_to_local_conversion() {
    let serial = UniverseComm::NoComm(NoComm);
    let dist =
        BlockCyclicDistribution::new(ProcessGrid::new_auto(&serial).unwrap(), 8, 8, 4, 4);
    for idx in [0, 3, 7] {
        assert_eq!(dist.global_to_local_row(idx), Some(idx));
        assert_eq!(dist.global_to_local_col(idx), Some(idx));
    }
}
/// Distributing a 0x0 matrix yields an empty local CSR whose row_ptr is just `[0]`.
#[test]
fn distribute_handles_empty() {
    let empty = CsrMatrix::from_csr(0, 0, vec![0], vec![], vec![]);
    let single = ProcessGrid {
        prows: 1,
        pcols: 1,
        my_prow: 0,
        my_pcol: 0,
        my_rank: 0,
        total_procs: 1,
    };
    let layout = BlockCyclicDistribution::new(single, 0, 0, 4, 4);
    let local = SuperLuDistSolver::new()
        .distribute_matrix(&empty, &layout)
        .unwrap();

    assert_eq!((local.nrows(), local.ncols()), (0, 0));
    assert_eq!(local.row_ptr(), &[0]);
}
/// A 5x3 matrix distributed on a 2x2 grid with 2x2 blocks gives this rank a well-formed
/// local CSR whose column indices are all below its local column count.
#[test]
fn distribute_non_square_and_small_blocks() {
    let row_ptr = vec![0, 1, 2, 2, 3, 3];
    let cols = vec![0, 1, 2];
    let vals = vec![1.0, 2.0, 3.0];
    let a = CsrMatrix::from_csr(5, 3, row_ptr, cols, vals);

    let grid = ProcessGrid {
        prows: 2,
        pcols: 2,
        my_prow: 0,
        my_pcol: 0,
        my_rank: 0,
        total_procs: 4,
    };
    let dist = BlockCyclicDistribution::new(grid, 5, 3, 2, 2);
    let local = SuperLuDistSolver::new().distribute_matrix(&a, &dist).unwrap();

    assert!(validate_local_csr(&local).is_ok());
    let ncols = local.ncols();
    assert!(local.col_idx().iter().all(|&c| c < ncols));
}
/// Off-diagonal entries of a symmetric 3x3 pattern become neighbours in the adjacency
/// graph. The stored diagonal entries (0,0) and (2,2) do not appear as self-loops.
#[test]
fn test_graph_creation() {
    let row_ptr = vec![0, 2, 4, 6];
    let cols = vec![0, 1, 0, 2, 1, 2];
    let vals = vec![2.0, -1.0, -1.0, 2.0, -1.0, 2.0];
    let pattern = CsrMatrix::from_csr(3, 3, row_ptr, cols, vals);

    let g = Graph::from_matrix_pattern(&pattern);
    assert_eq!(g.adj[0], vec![1]);
    assert_eq!(g.adj[1], vec![0, 2]);
    assert_eq!(g.adj[2], vec![1]);
}
/// Natural ordering is the identity permutation.
#[test]
fn test_natural_ordering() {
    let identity: Vec<usize> = (0..5).collect();
    assert_eq!(OrderingAlgorithms::natural_ordering(5), identity);
}
/// AMD ordering of a 4x4 pattern returns a permutation of `0..4`.
#[test]
fn test_amd_ordering() {
    let ones = vec![1.0; 8];
    let m = CsrMatrix::from_csr(
        4,
        4,
        vec![0, 2, 4, 6, 8],
        vec![0, 1, 1, 2, 2, 3, 0, 3],
        ones,
    );
    let perm = OrderingAlgorithms::amd_ordering(&m);
    assert_eq!(perm.len(), 4);

    // A permutation of 0..4 sorts back to [0, 1, 2, 3].
    let mut seen = perm;
    seen.sort_unstable();
    assert_eq!(seen, vec![0, 1, 2, 3]);
}
/// With identity row and column permutations, the symbolic pattern contains every
/// diagonal position.
#[test]
fn test_symbolic_factorization_pattern() {
    let m = CsrMatrix::from_csr(
        3,
        3,
        vec![0, 2, 4, 6],
        vec![0, 1, 1, 2, 0, 2],
        vec![2.0, -1.0, 2.0, -1.0, -1.0, 2.0],
    );
    let identity = vec![0, 1, 2];
    let pattern =
        SymbolicFactorizer::compute_symbolic_pattern(&m, &identity, &identity).unwrap();
    for d in 0..3 {
        assert!(pattern.contains_key(&(d, d)));
    }
}
/// An elimination tree built from the full lower-triangular 3x3 pattern has parent and
/// children vectors of length n.
#[test]
fn test_elimination_tree_construction() {
    let n = 3;
    // Every (i, j) with j <= i: the dense lower triangle.
    let l_pattern: HashMap<_, _> = (0..n)
        .flat_map(|i| (0..=i).map(move |j| ((i, j), true)))
        .collect();
    let etree = SymbolicFactorizer::build_elimination_tree(n, &l_pattern);
    assert_eq!(etree.parent.len(), n);
    assert_eq!(etree.children.len(), n);
}
/// Symbolic factorization with `ColumnPermutation::MmdAta` yields length-3 row and column
/// permutations and non-empty L and U patterns.
#[test]
fn test_enhanced_symbolic_factorization() {
    let matrix = CsrMatrix::from_csr(
        3,
        3,
        vec![0, 2, 4, 6],
        vec![0, 1, 1, 2, 0, 2],
        vec![2.0, -1.0, 2.0, -1.0, -1.0, 2.0],
    );
    let comm = UniverseComm::NoComm(NoComm);
    let distribution =
        BlockCyclicDistribution::new(ProcessGrid::new_auto(&comm).unwrap(), 3, 3, 2, 2);

    let mut solver = SuperLuDistSolver::new();
    solver.set_column_permutation(ColumnPermutation::MmdAta);

    // `process_grid` must be cloned out before `distribution` is moved into the struct.
    let slu_data = SuperLuDistData {
        process_grid: distribution.grid.clone(),
        distribution,
        comm: UniverseComm::NoComm(NoComm),
        options: solver.options.clone(),
        factored: false,
        local_matrix: Some(matrix),
        symbolic_factor: None,
        numeric_factor: None,
        solve_workspace: None,
    };

    let symbolic = solver.symbolic_factorization(&slu_data, &comm).unwrap();
    assert_eq!(symbolic.col_perm.len(), 3);
    assert_eq!(symbolic.row_perm.len(), 3);
    assert!(!symbolic.l_pattern.is_empty());
    assert!(!symbolic.u_pattern.is_empty());
}
/// Extracting columns 0..2 of a 4x4 sparse matrix gives a 2-wide, 4-high dense panel.
/// The expected entries match column-major storage (index = row + col * height), with
/// zero fill for missing entries.
#[test]
fn test_panel_creation() {
    let m = CsrMatrix::from_csr(
        4,
        4,
        vec![0, 2, 4, 6, 8],
        vec![0, 1, 1, 2, 2, 3, 0, 3],
        vec![2.0, -1.0, 2.0, -1.0, 2.0, -1.0, -1.0, 2.0],
    );
    let panel = Panel::from_sparse_columns(&m, 0, 2, vec![0, 1, 2, 3]);

    assert_eq!(panel.width, 2);
    assert_eq!(panel.height, 4);
    assert_eq!(panel.col_start, 0);
    assert_eq!(panel.data.len(), 8);
    // A(0,0)=2, A(1,0)=0 (absent), A(0,1)=-1, A(1,1)=2.
    for (idx, want) in [(0, 2.0), (1, 0.0), (4, -1.0), (5, 2.0)] {
        assert_eq!(panel.data[idx], want);
    }
}
/// Static LU of a well-conditioned 2x2 panel performs no row swaps, is non-singular, and
/// overwrites the panel data in place.
#[test]
fn test_panel_factorization_static() {
    let before = vec![2.0, 1.0, 1.0, 3.0];
    let mut panel = Panel {
        width: 2,
        height: 2,
        data: before.clone(),
        row_indices: vec![0, 1],
        col_start: 0,
    };
    let outcome = panel.factorize_lu(1e-12, PivotingStrategy::Static).unwrap();

    assert_eq!(outcome.pivot_strategy, PivotingStrategy::Static);
    assert_eq!(outcome.num_row_swaps, 0);
    assert!(!outcome.is_singular);
    assert_ne!(panel.data, before);
}
/// Dynamic pivoting on panel data [1, 3, 2, 1] reports the `Dynamic` strategy. It must
/// either swap rows or report a non-singular factorization.
#[test]
fn test_panel_factorization_dynamic() {
    let mut panel = Panel {
        width: 2,
        height: 2,
        data: vec![1.0, 3.0, 2.0, 1.0],
        row_indices: vec![0, 1],
        col_start: 0,
    };
    let outcome = panel
        .factorize_lu(1e-12, PivotingStrategy::Dynamic)
        .unwrap();

    assert_eq!(outcome.pivot_strategy, PivotingStrategy::Dynamic);
    let swapped = outcome.num_row_swaps > 0;
    assert!(swapped || !outcome.is_singular);
}
/// Under static pivoting, a leading entry of 1e-15 (below the 1e-12 threshold) is flagged
/// as singular, and the stored pivot ends up with magnitude at least the threshold.
#[test]
fn test_panel_factorization_tiny_pivot() {
    let threshold = 1e-12;
    let mut panel = Panel {
        width: 2,
        height: 2,
        data: vec![1e-15, 1.0, 1.0, 3.0],
        row_indices: vec![0, 1],
        col_start: 0,
    };
    let outcome = panel.factorize_lu(threshold, PivotingStrategy::Static).unwrap();

    assert!(outcome.is_singular);
    assert_eq!(outcome.pivot_strategy, PivotingStrategy::Static);
    assert!(panel.data[0].abs() >= threshold);
}
/// `ThresholdWithFallback` on a panel whose leading entry (1e-15) is below the 1e-12
/// threshold reports that it fell back to `Dynamic` pivoting.
#[test]
fn test_threshold_with_fallback() {
    let mut panel = Panel {
        width: 2,
        height: 2,
        data: vec![1e-15, 1.0, 1.0, 3.0],
        row_indices: vec![0, 1],
        col_start: 0,
    };
    let used = panel
        .factorize_lu(1e-12, PivotingStrategy::ThresholdWithFallback)
        .unwrap()
        .pivot_strategy;
    assert_eq!(used, PivotingStrategy::Dynamic);
}
/// Numerical factorization of a 3x3 matrix under static pivoting with panel size 2.
/// The symbolic factorization is supplied by hand: identity permutations and
/// diagonal-only L/U patterns. The result's dimensions, pivot strategy, panel bookkeeping
/// and statistics are checked.
#[test]
fn test_numerical_factorization_integration() {
let matrix = CsrMatrix::from_csr(
3,
3,
vec![0, 2, 4, 6],
vec![0, 1, 1, 2, 0, 2],
vec![4.0, -1.0, 4.0, -1.0, -1.0, 4.0],
);
let comm = UniverseComm::NoComm(NoComm);
let grid = ProcessGrid::new_auto(&comm).unwrap();
let distribution = BlockCyclicDistribution::new(grid, 3, 3, 2, 2);
let mut solver = SuperLuDistSolver::new();
solver.set_static_pivoting(true).set_panel_size(2);
// Hand-built symbolic factor: identity row/column permutations, a flat elimination tree
// in which every node's parent is index 3 (= n, presumably a virtual root; confirm),
// and diagonal-only L/U patterns.
let symbolic = SymbolicFactorization {
col_perm: vec![0, 1, 2],
row_perm: vec![0, 1, 2],
etree: EliminationTree {
parent: vec![3, 3, 3],
children: vec![vec![], vec![], vec![], vec![0, 1, 2]],
post_order: vec![0, 1, 2],
},
l_pattern: [(0, 0), (1, 1), (2, 2)]
.iter()
.map(|&k| (k, true))
.collect(),
u_pattern: [(0, 0), (1, 1), (2, 2)]
.iter()
.map(|&k| (k, true))
.collect(),
};
// `process_grid` is cloned before `distribution` is moved into the struct.
let slu_data = SuperLuDistData {
process_grid: distribution.grid.clone(),
distribution,
comm: UniverseComm::NoComm(NoComm),
options: solver.options.clone(),
factored: false,
local_matrix: Some(matrix),
symbolic_factor: Some(symbolic),
numeric_factor: None,
solve_workspace: None,
};
let numeric = solver
.numerical_factorization(&slu_data, &slu_data.comm)
.unwrap();
assert_eq!(numeric.n, 3);
// Static pivoting was requested above, so the factor must report it.
assert_eq!(numeric.pivot_strategy, PivotingStrategy::Static);
// Every panel has a matching factor entry.
assert!(!numeric.panels.is_empty());
assert_eq!(numeric.panels.len(), numeric.panel_factors.len());
assert_eq!(numeric.global_row_perm.len(), 3);
assert_eq!(numeric.global_col_perm.len(), 3);
assert!(numeric.factor_stats.num_panels > 0);
assert!(numeric.factor_stats.memory_usage > 0);
// Pivot growth is treated as a ratio that is never below 1.
assert!(numeric.factor_stats.max_pivot_growth >= 1.0);
}
/// The setters for 3D factorization, memory trade-off, panel concurrency and asynchronous
/// panel updates are all reflected in `options()`.
#[test]
fn test_3d_factorization_options() {
    let mut solver = SuperLuDistSolver::new();
    solver.set_3d_factorization(true, Some(2));
    solver.set_memory_tradeoff(2.5);
    solver.set_max_concurrent_panels(4);
    solver.set_async_panel_updates(true);

    let opts = solver.options();
    assert!(opts.enable_3d_factorization);
    assert_eq!(opts.process_grid_3d_depth, Some(2));
    assert_eq!(opts.memory_tradeoff_factor, 2.5);
    assert_eq!(opts.max_concurrent_panels, 4);
    assert!(opts.async_panel_updates);
}
/// Each pivoting-related setter is visible in `options()` right after it is called.
#[test]
fn test_pivoting_strategies() {
    let mut s = SuperLuDistSolver::new();

    s.set_static_pivoting(true);
    assert!(s.options().static_pivoting);

    s.set_diagonal_pivot_threshold(0.1);
    assert_eq!(s.options().diagonal_pivot_threshold, 0.1);

    s.set_panel_size(32);
    assert_eq!(s.options().panel_size, Some(32));
}
/// `TriangularSolveData::new` for n = 8 with block size 4 on a single-process grid.
/// The numeric factor is a stub with no panels and identity permutations. Checks the
/// number of per-block owners and dependency lists, and that a communication buffer
/// is allocated.
#[test]
fn test_triangular_solve_data_creation() {
let comm = UniverseComm::NoComm(NoComm);
let grid = ProcessGrid::new_auto(&comm).unwrap();
let distribution = BlockCyclicDistribution::new(grid, 8, 8, 2, 2);
// Stub factorization: identity permutations, unit scaling, no panels, and one
// (empty) block-graph entry per solve block.
let numeric_factor = NumericFactorization {
n: 8,
nnz: 16,
panels: vec![],
panel_factors: vec![],
global_row_perm: (0..8).collect(),
global_col_perm: (0..8).collect(),
row_scale: vec![1.0; 8],
col_scale: vec![1.0; 8],
pivot_strategy: PivotingStrategy::Static,
pivot_threshold: 1e-12,
replaced_tiny_pivots: false,
factor_stats: FactorizationStats {
num_panels: 0,
total_row_swaps: 0,
tiny_pivots_replaced: 0,
max_pivot_growth: 1.0,
condition_estimate: None,
memory_usage: 0,
},
l_block_graph: vec![vec![], vec![]],
u_block_graph: vec![vec![], vec![]],
};
// Arguments: n = 8 and block size 4, so the solve is split into 8 / 4 = 2 blocks.
let solve_data =
TriangularSolveData::new(8, 4, &distribution, &numeric_factor, vec![vec![], vec![]]);
// One owner and one dependency list per block.
assert_eq!(solve_data.block_owners.len(), 2); assert_eq!(solve_data.dependency_graph.len(), 2);
assert!(!solve_data.comm_buffer.is_empty());
}
/// On a 2x3 grid, diagonal block k is owned by the rank at grid coordinates
/// (k mod prows, k mod pcols).
#[test]
fn diag_owner_is_mod_coords() {
    let grid = ProcessGrid {
        prows: 2,
        pcols: 3,
        my_prow: 0,
        my_pcol: 0,
        my_rank: 0,
        total_procs: 6,
    };
    let dist = BlockCyclicDistribution::new(grid, 128, 128, 4, 4);
    let g = &dist.grid;
    for k in 0..12 {
        let want = g.coords_to_rank(k % g.prows, k % g.pcols);
        assert_eq!(dist.owner_rank_of_diag_block(k), want);
    }
}
/// With n = 130 and block size 64, the solve blocks have sizes 64, 64 and a trailing
/// remainder block of 2.
#[test]
fn block_sizes_are_exact() {
let grid = ProcessGrid {
prows: 2,
pcols: 2,
my_prow: 0,
my_pcol: 0,
my_rank: 0,
total_procs: 4,
};
let dist = BlockCyclicDistribution::new(grid, 130, 130, 64, 64);
let n = 130usize;
let bs = 64usize;
// Minimal stub factorization: only `n` and the three block-graph entries matter here.
let nf = NumericFactorization {
n,
nnz: 0,
panels: vec![],
panel_factors: vec![],
global_row_perm: vec![],
global_col_perm: vec![],
row_scale: vec![],
col_scale: vec![],
pivot_strategy: PivotingStrategy::Static,
pivot_threshold: 1.0,
replaced_tiny_pivots: false,
factor_stats: FactorizationStats {
num_panels: 0,
total_row_swaps: 0,
tiny_pivots_replaced: 0,
max_pivot_growth: 1.0,
condition_estimate: None,
memory_usage: 0,
},
l_block_graph: vec![vec![], vec![], vec![]],
u_block_graph: vec![vec![], vec![], vec![]],
};
let t = TriangularSolveData::new(n, bs, &dist, &nf, vec![vec![], vec![], vec![]]);
// 130 = 64 + 64 + 2: the last block holds only the remainder rows.
assert_eq!(t.block_sizes, vec![64, 64, 2]);
}
/// `CommPattern` variants compare by identity, and `CommRequest::new` records the
/// request id and communication type it was given.
#[test]
fn test_communication_patterns() {
    let tree = CommPattern::BinaryTree;
    assert_eq!(tree, CommPattern::BinaryTree);
    assert_ne!(tree, CommPattern::PointToPoint);

    let req = CommRequest::new(1, 0, 1, 100, CommType::Send, 64);
    assert_eq!(req.request_id, 1);
    assert_eq!(req.comm_type, CommType::Send);
}
/// Coarsening a 6x6 lower-triangular L pattern into 2x2 blocks gives strictly-lower block
/// edges 1 -> 0 and 2 -> 1. The entries (4|5, 0|1) are removed, so block 2 has no edge
/// to block 0.
///
/// NOTE(review): the coarsening loop is written inline here instead of calling
/// production code, so this only pins the expected graph shape for this pattern.
#[test]
fn l_block_graph_coarsens_symbolic() {
    let removed = [(4, 1), (5, 1), (4, 0), (5, 0)];
    let lpat: HashMap<(usize, usize), bool> = (0..6usize)
        .flat_map(|i| (0..=i).map(move |j| (i, j)))
        .filter(|ij| !removed.contains(ij))
        .map(|ij| (ij, true))
        .collect();

    let symbolic = SymbolicFactorization {
        col_perm: (0..6).collect(),
        row_perm: (0..6).collect(),
        etree: EliminationTree {
            parent: vec![6; 6],
            children: vec![vec![]; 6],
            post_order: vec![],
        },
        l_pattern: lpat,
        u_pattern: HashMap::new(),
    };

    let (bs, nb) = (2, 3);
    let mut lbg = vec![Vec::<usize>::new(); nb];
    for &(i, j) in symbolic.l_pattern.keys() {
        let (bi, bj) = (i / bs, j / bs);
        // Only strictly-lower block pairs become edges, each recorded once.
        if bj < bi && !lbg[bi].contains(&bj) {
            lbg[bi].push(bj);
        }
    }
    lbg.iter_mut().for_each(|edges| edges.sort_unstable());

    assert_eq!(lbg[0], Vec::<usize>::new());
    assert_eq!(lbg[1], vec![0usize]);
    assert_eq!(lbg[2], vec![1usize]);
}
/// Forward substitution with a single 2x2 panel. The expected result
/// [3, 11 - 2*3] = [3, 5] implies a unit diagonal (data[3] = 5.0 is never divided by)
/// and strictly-lower entry data[1] = 2.0.
#[test]
fn test_local_triangular_solve_l() {
    let factors = vec![Panel {
        width: 2,
        height: 2,
        data: vec![1.0, 2.0, 4.0, 5.0],
        row_indices: vec![0, 1],
        col_start: 0,
    }];
    let mut x = vec![3.0, 11.0];
    DistributedTriangularSolver::solve_local_l_block(&mut x, &factors, 0).unwrap();

    let expected = [3.0, 5.0];
    for (got, want) in x.iter().zip(expected) {
        assert!((got - want).abs() < 1e-10);
    }
}
/// Back substitution with panel data [2, 0, 1, 3], read column-major as
/// U = [[2, 1], [0, 3]]: x = [4, 6] gives x1 = 6/3 = 2 and x0 = (4 - 1*2)/2 = 1.
#[test]
fn test_local_triangular_solve_u() {
    let factors = vec![Panel {
        width: 2,
        height: 2,
        data: vec![2.0, 0.0, 1.0, 3.0],
        row_indices: vec![0, 1],
        col_start: 0,
    }];
    let mut x = vec![4.0, 6.0];
    DistributedTriangularSolver::solve_local_u_block(&mut x, &factors, 0).unwrap();

    let expected = [1.0, 2.0];
    for (got, want) in x.iter().zip(expected) {
        assert!((got - want).abs() < 1e-10);
    }
}
/// `apply_block_update` with indices 0, 0, panel data [1, 3, 2, 4] and update [1, 1]
/// turns x = [5, 7] into [4, 4].
///
/// NOTE(review): the expected result equals x minus the panel's first column [1, 3];
/// confirm this is the intended semantics of the update.
#[test]
fn test_block_update_application() {
    let factors = vec![Panel {
        width: 2,
        height: 2,
        data: vec![1.0, 3.0, 2.0, 4.0],
        row_indices: vec![0, 1],
        col_start: 0,
    }];
    let mut x_block = vec![5.0, 7.0];
    let incoming = vec![1.0, 1.0];
    DistributedTriangularSolver::apply_block_update(&mut x_block, &incoming, 0, 0, &factors)
        .unwrap();

    for value in &x_block {
        assert!((value - 4.0).abs() < 1e-10);
    }
}
/// End-to-end solve of a 4x4 system through the legacy `LinearSolver` interface, with
/// asynchronous panel updates, 3D factorization off and at most two concurrent panels.
/// Expects one iteration, `ConvergedAtol`, and an all-finite solution.
#[test]
fn test_distributed_solve_integration() {
    let a = CsrMatrix::from_csr(
        4,
        4,
        vec![0, 2, 4, 6, 8],
        vec![0, 1, 1, 2, 2, 3, 0, 3],
        vec![4.0, -1.0, 4.0, -1.0, 4.0, -1.0, -1.0, 4.0],
    );
    let rhs = vec![1.0, 2.0, 3.0, 4.0];
    let mut sol = vec![0.0; 4];

    let mut solver = SuperLuDistSolver::new();
    solver
        .set_async_panel_updates(true)
        .set_3d_factorization(false, None)
        .set_max_concurrent_panels(2);

    let comm = UniverseComm::NoComm(NoComm);
    let stats = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &a,
        None,
        &rhs,
        &mut sol,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();

    assert_eq!(stats.iterations, 1);
    assert!(matches!(stats.reason, ConvergedReason::ConvergedAtol));
    assert!(sol.iter().copied().all(f64::is_finite));
}
/// Communication-overlap setters (async updates, concurrency, 3D depth, memory trade-off)
/// each take effect immediately.
#[test]
fn test_communication_overlap_options() {
    let mut s = SuperLuDistSolver::new();

    s.set_async_panel_updates(true);
    assert!(s.options().async_panel_updates);

    s.set_max_concurrent_panels(8);
    assert_eq!(s.options().max_concurrent_panels, 8);

    s.set_3d_factorization(true, Some(4));
    let opts = s.options();
    assert!(opts.enable_3d_factorization);
    assert_eq!(opts.process_grid_3d_depth, Some(4));

    s.set_memory_tradeoff(3.0);
    assert_eq!(s.options().memory_tradeoff_factor, 3.0);
}
/// Solving with the 5x5 identity returns the right-hand side exactly in one iteration,
/// with `ConvergedAtol`.
#[test]
fn test_superlu_dist_simple_solve() {
    let eye = CsrMatrix::identity(5);
    let rhs = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let mut sol = vec![0.0; 5];
    let mut solver = SuperLuDistSolver::new();
    let comm = UniverseComm::NoComm(NoComm);

    let stats = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &eye,
        None,
        &rhs,
        &mut sol,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();

    assert_eq!(sol, rhs);
    assert_eq!(stats.iterations, 1);
    assert!(matches!(stats.reason, ConvergedReason::ConvergedAtol));
}
/// Compares the solver's answer on a symmetric 3x3 matrix (natural column ordering, no
/// row permutation) against a dense `FullPivLu` reference solve, to 1e-8.
#[test]
#[ignore]
fn test_superlu_dist_spd_solve() {
    let a = CsrMatrix::from_csr(
        3,
        3,
        vec![0, 3, 6, 9],
        vec![0, 1, 2, 0, 1, 2, 0, 1, 2],
        vec![4.0, 1.0, 0.0, 1.0, 3.0, 1.0, 0.0, 1.0, 2.0],
    );
    let b = vec![1.0, 2.0, 3.0];
    let mut x = vec![0.0; 3];

    let mut solver = SuperLuDistSolver::new();
    solver
        .set_column_permutation(ColumnPermutation::Natural)
        .set_row_permutation(RowPermutation::NoRowPerm);
    let comm = UniverseComm::NoComm(NoComm);
    let _ = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &a,
        None,
        &b,
        &mut x,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();

    // Dense reference: x_ref = A^{-1} b via full-pivoting LU.
    let dense = a.to_dense().unwrap();
    let reference_lu = FullPivLu::new(dense.as_ref());
    let mut x_ref = b.clone();
    reference_lu.solve_in_place_with_conj(
        faer::Conj::No,
        MatMut::from_column_major_slice_mut(&mut x_ref, 3, 1),
    );

    for (got, want) in x.iter().zip(&x_ref) {
        assert!((got - want).abs() < 1e-8);
    }
}
/// Compares the solver's answer on the symmetric indefinite matrix [[2, 1], [1, -1]]
/// against a dense `FullPivLu` reference solve, to 1e-8.
#[test]
#[ignore]
fn test_superlu_dist_indefinite_solve() {
    let a = CsrMatrix::from_csr(
        2,
        2,
        vec![0, 2, 4],
        vec![0, 1, 0, 1],
        vec![2.0, 1.0, 1.0, -1.0],
    );
    let b = vec![1.0, 0.0];
    let mut x = vec![0.0; 2];

    let mut solver = SuperLuDistSolver::new();
    solver
        .set_column_permutation(ColumnPermutation::Natural)
        .set_row_permutation(RowPermutation::NoRowPerm);
    let comm = UniverseComm::NoComm(NoComm);
    let _ = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &a,
        None,
        &b,
        &mut x,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();

    let dense = a.to_dense().unwrap();
    let reference_lu = FullPivLu::new(dense.as_ref());
    let mut x_ref = b.clone();
    reference_lu.solve_in_place_with_conj(
        faer::Conj::No,
        MatMut::from_column_major_slice_mut(&mut x_ref, 2, 1),
    );

    for (got, want) in x.iter().zip(&x_ref) {
        assert!((got - want).abs() < 1e-8);
    }
}
/// Builds the 6x6 tridiagonal matrix with 4.0 on the diagonal and -1.0 on the first sub-
/// and super-diagonals, in CSR form with ascending column indices in each row.
fn make_spd6() -> CsrMatrix<f64> {
    let n: usize = 6;
    let mut row_ptr = Vec::with_capacity(n + 1);
    row_ptr.push(0);
    let mut col = Vec::new();
    let mut val = Vec::new();
    for i in 0..n {
        // Columns i-1..=i+1, clipped to the matrix bounds.
        let lo = i.saturating_sub(1);
        let hi = (i + 1).min(n - 1);
        for j in lo..=hi {
            col.push(j);
            val.push(if j == i { 4.0 } else { -1.0 });
        }
        row_ptr.push(col.len());
    }
    CsrMatrix::from_csr(n, n, row_ptr, col, val)
}
/// Compares the solver's answer on the 6x6 tridiagonal matrix from `make_spd6` (default
/// options) against a dense `FullPivLu` reference solve, to 1e-8.
///
/// NOTE(review): despite the name, the matrix is deterministic, not random.
#[test]
#[ignore]
fn test_superlu_dist_random_spd() {
    let a = make_spd6();
    let b = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let mut x = vec![0.0; 6];

    let mut solver = SuperLuDistSolver::new();
    let comm = UniverseComm::NoComm(NoComm);
    let _ = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &a,
        None,
        &b,
        &mut x,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();

    let dense = a.to_dense().unwrap();
    let reference_lu = FullPivLu::new(dense.as_ref());
    let mut x_ref = b.clone();
    reference_lu.solve_in_place_with_conj(
        faer::Conj::No,
        MatMut::from_column_major_slice_mut(&mut x_ref, 6, 1),
    );

    for (got, want) in x.iter().zip(&x_ref) {
        assert!((got - want).abs() < 1e-8);
    }
}
/// With tiny-pivot replacement and static pivoting enabled, the leading entry 1e-12 (below
/// the 1e-8 threshold) is counted in the stored factorization's `tiny_pivots_replaced`.
#[test]
fn test_superlu_dist_tiny_pivot_replacement() {
    let a = CsrMatrix::from_csr(
        2,
        2,
        vec![0, 2, 4],
        vec![0, 1, 0, 1],
        vec![1e-12, 1.0, 1.0, 1.0],
    );
    let rhs = vec![1.0, 2.0];
    let mut sol = vec![0.0; 2];

    let mut solver = SuperLuDistSolver::new();
    solver
        .set_replace_tiny_pivots(true)
        .set_static_pivoting(true)
        .set_diagonal_pivot_threshold(1e-8);
    let comm = UniverseComm::NoComm(NoComm);
    let _ = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &a,
        None,
        &rhs,
        &mut sol,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();

    let replaced = solver
        .data
        .as_ref()
        .and_then(|d| d.numeric_factor.as_ref())
        .map(|nf| nf.factor_stats.tiny_pivots_replaced)
        .expect("numeric factorization should be stored after a solve");
    assert!(replaced > 0);
}
/// A right-hand side shorter than the matrix dimension is rejected with
/// `KError::InvalidInput`.
#[test]
fn test_invalid_input_dimensions() {
    let eye3 = CsrMatrix::from_csr(3, 3, vec![0, 1, 2, 3], vec![0, 1, 2], vec![1.0, 1.0, 1.0]);
    // Deliberately one entry short for the 3x3 system.
    let b = vec![1.0, 2.0];
    let mut x = vec![0.0; 3];
    let mut solver = SuperLuDistSolver::new();
    let comm = UniverseComm::NoComm(NoComm);

    let outcome = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &eye3,
        None,
        &b,
        &mut x,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    );
    assert!(matches!(outcome, Err(KError::InvalidInput(_))));
}
/// Solving twice with the same solver keeps its stored data, and both solves return the
/// exact solution of diag(2, 3) x = b.
///
/// The solutions are checked as well, so a stale or corrupted factorization reused by the
/// second solve would be caught, not just the presence of `data`.
#[test]
fn test_solver_reuse() {
    let matrix = CsrMatrix::from_csr(2, 2, vec![0, 1, 2], vec![0, 1], vec![2.0, 3.0]);
    let mut solver = SuperLuDistSolver::new();
    let comm = UniverseComm::NoComm(NoComm);
    let close = |got: &[f64], want: &[f64]| {
        got.len() == want.len()
            && got.iter().zip(want).all(|(g, w)| (g - w).abs() < 1e-12)
    };

    let b1 = vec![2.0, 3.0];
    let mut x1 = vec![0.0; 2];
    let _stats1 = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &matrix,
        None,
        &b1,
        &mut x1,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();
    assert!(solver.data.is_some());
    // diag(2, 3) * [1, 1] = [2, 3]
    assert!(close(&x1, &[1.0, 1.0]), "first solve returned {x1:?}");

    let b2 = vec![4.0, 6.0];
    let mut x2 = vec![0.0; 2];
    let _stats2 = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &matrix,
        None,
        &b2,
        &mut x2,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    )
    .unwrap();
    assert!(solver.data.is_some());
    // diag(2, 3) * [2, 2] = [4, 6]
    assert!(close(&x2, &[2.0, 2.0]), "second solve returned {x2:?}");
}
/// `set_refinement_config` creates a refinement engine holding exactly the supplied
/// configuration.
#[test]
fn test_refinement_config() {
    let mut solver = SuperLuDistSolver::new();
    solver.set_refinement_config(RefinementConfig {
        max_iterations: 10,
        tolerance: 1e-14,
        relative_tolerance: 1e-8,
        min_improvement_factor: 0.8,
    });

    let engine = solver
        .refinement_engine
        .as_ref()
        .expect("setting a refinement config should create an engine");
    let cfg = &engine.config;
    assert_eq!(cfg.max_iterations, 10);
    assert_eq!(cfg.tolerance, 1e-14);
    assert_eq!(cfg.relative_tolerance, 1e-8);
    assert_eq!(cfg.min_improvement_factor, 0.8);
}
/// Enabling refinement creates an engine, `set_residual_method` updates it, and
/// disabling refinement removes it again.
#[test]
fn test_refinement_methods() {
    let mut s = SuperLuDistSolver::new();

    s.enable_iterative_refinement();
    assert!(s.refinement_engine.is_some());

    s.set_residual_method(ResidualMethod::Scaled);
    if let Some(engine) = s.refinement_engine.as_ref() {
        assert_eq!(engine.residual_method, ResidualMethod::Scaled);
    }

    s.disable_iterative_refinement();
    assert!(s.refinement_engine.is_none());
}
/// A new engine keeps the requested residual method and has no statistics yet.
#[test]
fn test_refinement_engine_creation() {
    let engine =
        RefinementEngine::new(RefinementConfig::default(), ResidualMethod::ComponentWise);
    assert_eq!(engine.residual_method, ResidualMethod::ComponentWise);
    assert!(engine.last_stats.is_none());
}
/// Pins `check_convergence` under default settings for four (value, value, iteration)
/// inputs. At iteration 0 even a tiny first value is rejected.
#[test]
fn test_refinement_convergence_criteria() {
    let engine = RefinementEngine::with_defaults();
    let cases = [
        (1e-13, 1e-6, 1, true),
        (1e-7, 1e-1, 1, true),
        (1e-4, 1e-6, 1, false),
        (1e-13, 1e-6, 0, false),
    ];
    for (first, second, iteration, expected) in cases {
        assert_eq!(engine.check_convergence(first, second, iteration), expected);
    }
}
/// Sparse mat-vec on one process: rows [2,1,0], [0,3,1], [1,0,4] times [1,2,3] give
/// [4, 9, 13].
#[test]
fn test_distributed_sparse_matvec() {
    let a = CsrMatrix::from_csr(
        3,
        3,
        vec![0, 2, 4, 6],
        vec![0, 1, 1, 2, 0, 2],
        vec![2.0, 1.0, 3.0, 1.0, 1.0, 4.0],
    );
    let input = vec![1.0, 2.0, 3.0];
    let mut output = vec![0.0; 3];
    let serial = UniverseComm::NoComm(NoComm);

    RefinementEngine::distributed_sparse_matvec_static(&a, &input, &mut output, &serial)
        .unwrap();
    assert_eq!(output, vec![4.0, 9.0, 13.0]);
}
/// A hand-built `RefinementStats` exposes the field values it was constructed with.
#[test]
fn test_refinement_stats() {
    let history = vec![1e-3, 1e-6, 1e-9, 1e-12];
    let stats = RefinementStats {
        iterations: 3,
        initial_residual_norm: 1e-3,
        final_residual_norm: 1e-12,
        residual_history: history,
        converged: true,
        convergence_reason: RefinementConvergence::AbsoluteTolerance,
        refinement_time: 0.001,
    };

    assert_eq!(stats.iterations, 3);
    assert!(stats.converged);
    assert_eq!(stats.residual_history.len(), 4);
    assert!(matches!(
        stats.convergence_reason,
        RefinementConvergence::AbsoluteTolerance
    ));
}
/// `ResidualMethod` equality distinguishes its variants.
#[test]
fn test_residual_methods() {
    let standard = ResidualMethod::Standard;
    let scaled = ResidualMethod::Scaled;
    let component = ResidualMethod::ComponentWise;
    assert_eq!(standard, ResidualMethod::Standard);
    assert_ne!(standard, scaled);
    assert_ne!(scaled, component);
}
/// `setup_workspace(n)` sizes the residual, correction and mat-vec buffers to n.
#[test]
fn test_refinement_workspace_setup() {
    let n = 100;
    let mut engine = RefinementEngine::with_defaults();
    engine.setup_workspace(n);

    for len in [
        engine.residual_workspace.len(),
        engine.correction_workspace.len(),
        engine.matvec_workspace.len(),
    ] {
        assert_eq!(len, n);
    }
}
/// The norm of [3, 4, 0] computed on one process is 5 (a 3-4-5 triangle).
#[test]
fn test_vector_norm_computation() {
    let serial = UniverseComm::NoComm(NoComm);
    let v = vec![3.0, 4.0, 0.0];
    let norm = RefinementEngine::compute_vector_norm_static(&v, &serial).unwrap();
    assert!((norm - 5.0).abs() < 1e-10);
}
/// Every `RefinementConvergence` variant equals itself and differs from all the others.
#[test]
fn test_refinement_convergence_variants() {
    let reasons = [
        RefinementConvergence::AbsoluteTolerance,
        RefinementConvergence::RelativeTolerance,
        RefinementConvergence::MaxIterations,
        RefinementConvergence::Stagnation,
        RefinementConvergence::Divergence,
    ];
    for (i, lhs) in reasons.iter().enumerate() {
        for (j, rhs) in reasons.iter().enumerate() {
            if i != j {
                assert_ne!(lhs, rhs);
                continue;
            }
            assert_eq!(lhs, rhs);
        }
    }
}
/// Enabling refinement and then overriding its config and residual method is reflected in
/// the solver's refinement engine, and no refinement statistics exist before any solve.
///
/// NOTE(review): despite the name, this test performs no solve. The unused
/// matrix/rhs/solution/comm locals it used to build were dead setup and are gone.
#[test]
fn test_iterative_refinement_integration() {
    let mut solver = SuperLuDistSolver::new();
    solver.enable_iterative_refinement();
    solver.set_refinement_config(RefinementConfig {
        max_iterations: 3,
        tolerance: 1e-10,
        relative_tolerance: 1e-8,
        min_improvement_factor: 0.9,
    });
    solver.set_residual_method(ResidualMethod::Standard);

    let engine = solver
        .refinement_engine
        .as_ref()
        .expect("refinement engine should be configured");
    assert_eq!(engine.config.max_iterations, 3);
    assert_eq!(engine.config.tolerance, 1e-10);
    assert_eq!(engine.residual_method, ResidualMethod::Standard);

    assert!(solver.refinement_stats().is_none());
}
/// With A = I, b = [2, 3] and x = [1, 1], the standard residual is b - A x = [1, 2].
/// The scaled residual divides that by ||b||_2 = sqrt(13).
#[test]
fn test_refinement_residual_scaling() {
    let eye = CsrMatrix::from_csr(2, 2, vec![0, 1, 2], vec![0, 1], vec![1.0, 1.0]);
    let b = vec![2.0, 3.0];
    let x = vec![1.0, 1.0];
    let mut r = vec![0.0; 2];
    let mut scratch = vec![0.0; 2];
    let serial = UniverseComm::NoComm(NoComm);

    RefinementEngine::compute_residual_static(
        &eye,
        &b,
        &x,
        &mut r,
        &mut scratch,
        ResidualMethod::Standard,
        &serial,
    )
    .unwrap();
    assert_eq!(r, vec![1.0, 2.0]);

    RefinementEngine::compute_residual_static(
        &eye,
        &b,
        &x,
        &mut r,
        &mut scratch,
        ResidualMethod::Scaled,
        &serial,
    )
    .unwrap();
    let b_norm = (4.0 + 9.0_f64).sqrt();
    assert!((r[0] - 1.0 / b_norm).abs() < 1e-10);
    assert!((r[1] - 2.0 / b_norm).abs() < 1e-10);
}
/// Vectors of the requested length are handed out, can be returned to the pool, and can be
/// acquired again. Reported memory stays within the limit given as the second constructor
/// argument (presumably bytes, as in `test_memory_pool_size_limits`).
#[test]
fn test_memory_pool_basic_operations() {
    let limit_bytes = 1024 * 1024;
    let mut pool = MemoryPool::new(4, limit_bytes);

    let vec1 = pool.get_f64_vector(100);
    assert_eq!(vec1.len(), 100);
    let vec2 = pool.get_usize_vector(50);
    assert_eq!(vec2.len(), 50);

    pool.return_f64_vector(vec1);
    pool.return_usize_vector(vec2);

    let vec3 = pool.get_f64_vector(100);
    assert_eq!(vec3.len(), 100);

    // An unsigned "pooled > 0 || pooled == 0" check can never fail, so check the pool's
    // byte-limit invariant instead.
    let stats = pool.memory_stats();
    assert!(stats.total_memory_bytes <= limit_bytes);
}
/// Two get/return cycles on a pool created with limit 1024 keep its reported memory
/// within 1024 bytes.
#[test]
fn test_memory_pool_size_limits() {
    let mut pool = MemoryPool::new(2, 1024);
    for _ in 0..2 {
        let v = pool.get_f64_vector(10);
        pool.return_f64_vector(v);
    }
    assert!(pool.memory_stats().total_memory_bytes <= 1024);
}
/// Send, receive and local buffers come back with the requested lengths, including a
/// repeated send-buffer request. After clearing, reported memory is under 1 MiB.
#[test]
fn test_comm_buffer_manager() {
    let mut mgr = CommBufferManager::new(1000, 10);

    assert_eq!(mgr.get_send_buffer(0, 100).len(), 100);
    assert_eq!(mgr.get_recv_buffer(1, 200).len(), 200);
    assert_eq!(mgr.get_local_buffer(50).len(), 50);
    // A second request for the same send buffer still yields the requested length.
    assert_eq!(mgr.get_send_buffer(0, 100).len(), 100);

    mgr.clear_buffers();
    assert!(mgr.memory_stats().total_memory_bytes < 1024 * 1024);
}
/// A new workspace reports no temporary vectors and zero memory.
#[test]
fn test_superlu_dist_workspace_creation() {
    let ws = SuperLuDistWorkspace::new();
    let stats = ws.memory_stats();
    assert_eq!((stats.temp_vectors_count, stats.total_memory), (0, 0));
}
/// A workspace can be built, and its stats queried, for every preallocation strategy.
#[test]
fn test_workspace_config_variants() {
    let strategies = [
        PreallocationStrategy::None,
        PreallocationStrategy::MatrixSize,
        PreallocationStrategy::ProcessGrid,
        PreallocationStrategy::BlockSize,
        PreallocationStrategy::Full,
    ];
    strategies
        .into_iter()
        .map(|strategy| {
            SuperLuDistWorkspace::with_config(WorkspaceConfig {
                preallocation_strategy: strategy,
                ..Default::default()
            })
        })
        .for_each(|ws| {
            let _stats = ws.memory_stats();
        });
}
/// Re-checking out a named temp vector returns it at the requested length with element 0
/// reset to 0.0, even after it was dirtied. Returning it leaves at most one temp vector.
#[test]
fn test_workspace_temp_vector_management() {
    let mut ws = SuperLuDistWorkspace::new();
    {
        let first = ws.get_temp_vector("test_vec", 100);
        assert_eq!(first.len(), 100);
        // Dirty the buffer so the next checkout can show it was reset.
        first[0] = 42.0;
    }
    let second = ws.get_temp_vector("test_vec", 100);
    assert_eq!(second.len(), 100);
    assert_eq!(second[0], 0.0);

    ws.return_temp_vector("test_vec");
    assert!(ws.memory_stats().temp_vectors_count <= 1);
}
/// Setting up a workspace for a problem of size 1000 with block size 64 on a
/// single-process grid allocates some memory.
#[test]
fn test_workspace_setup_for_problem() {
    let serial = UniverseComm::NoComm(NoComm);
    let grid = ProcessGrid::new_auto(&serial).unwrap();
    let mut ws = SuperLuDistWorkspace::new();
    ws.setup_for_problem(1000, &grid, 64).unwrap();
    assert!(ws.memory_stats().total_memory > 0);
}
/// `optimize` never increases the workspace's reported memory.
#[test]
fn test_workspace_optimization() {
    let mut ws = SuperLuDistWorkspace::new();
    for (name, len) in [("temp1", 100), ("temp2", 200)] {
        ws.get_temp_vector(name, len);
    }

    let before = ws.memory_stats().total_memory;
    ws.optimize();
    let after = ws.memory_stats().total_memory;
    assert!(after <= before);
}
/// Workspace setters store the memory limit, aggressive-reuse flag and preallocation
/// strategy in the solver's workspace configuration.
#[test]
fn test_solver_workspace_configuration() {
    let mut solver = SuperLuDistSolver::new();
    solver
        .set_workspace_memory_limit(2048)
        .set_aggressive_memory_reuse(true)
        .set_preallocation_strategy(PreallocationStrategy::MatrixSize);

    assert_eq!(solver.workspace_config.memory_limit_mb, 2048);
    // `assert!` rather than `assert_eq!(.., true)` (clippy::bool_assert_comparison).
    assert!(solver.workspace_config.aggressive_reuse);
    assert_eq!(
        solver.workspace_config.preallocation_strategy,
        PreallocationStrategy::MatrixSize
    );
}
/// Setting a workspace memory limit records it in the config, and workspace memory stats
/// are `None` before the solver has been used.
#[test]
fn test_workspace_memory_stats() {
    let mut solver = SuperLuDistSolver::new();
    solver.set_workspace_memory_limit(512);

    // The limit is stored on the config even though no solve has happened yet.
    assert_eq!(solver.workspace_config.memory_limit_mb, 512);
    // Nothing has been solved, so there are no workspace stats to report
    // (presumably the workspace is created lazily; confirm).
    assert!(solver.workspace_memory_stats().is_none());
}
/// A fresh solver needs no workspace cleanup, and both optimizing and clearing its
/// workspace succeed.
#[test]
fn test_workspace_cleanup_detection() {
    let mut solver = SuperLuDistSolver::new();
    assert!(!solver.workspace_needs_cleanup());
    assert!(solver.optimize_workspace().is_ok());
    assert!(solver.clear_workspace_temp_data().is_ok());
}
/// Two small temp vectors in a workspace limited to 1 MB stay under 1 MiB of reported
/// memory, and clearing temp data does not increase it.
#[test]
fn test_workspace_memory_efficiency() {
    let config = WorkspaceConfig {
        memory_limit_mb: 1,
        aggressive_reuse: true,
        preallocation_strategy: PreallocationStrategy::None,
        ..Default::default()
    };
    let mut ws = SuperLuDistWorkspace::with_config(config);
    ws.get_temp_vector("small1", 10);
    ws.get_temp_vector("small2", 10);

    let used = ws.memory_stats().total_memory;
    let limit_bytes = 1024 * 1024;
    assert!(used < limit_bytes);

    ws.clear_temp_data();
    assert!(ws.memory_stats().total_memory <= used);
}
/// A `SolveWorkspace` wrapping a fresh workspace and empty vector maps starts with no
/// stored vectors and zero reported workspace memory.
///
/// Real assertions replace the no-op `assert!(true)` (clippy::assertions_on_constants).
#[test]
fn test_solve_workspace_integration() {
    let solve_workspace = SolveWorkspace {
        workspace: SuperLuDistWorkspace::new(),
        process_vectors: HashMap::new(),
        global_vectors: HashMap::new(),
    };

    assert!(solve_workspace.process_vectors.is_empty());
    assert!(solve_workspace.global_vectors.is_empty());
    // A freshly created workspace reports zero memory (see
    // `test_superlu_dist_workspace_creation`).
    let stats = solve_workspace.workspace.memory_stats();
    assert_eq!(stats.total_memory, 0);
}
/// Every option set through `SuperLuDistSolver::builder()` is present on the built
/// solver's options or workspace configuration.
///
/// Boolean options use `assert!` / `assert!(!..)` instead of comparing against
/// `true`/`false` (clippy::bool_assert_comparison).
#[test]
fn test_superlu_dist_builder_pattern() {
    let solver = SuperLuDistSolver::builder()
        .diagonal_pivot_threshold(0.2)
        .column_permutation(ColumnPermutation::Metis)
        .row_permutation(RowPermutation::LargeDiag)
        .iterative_refinement(IterativeRefinement::Double)
        .print_level(1)
        .replace_tiny_pivots(true)
        .static_pivoting(false)
        .process_grid(2, 2)
        .panel_size(32)
        .enable_3d_factorization(false, None)
        .memory_tradeoff_factor(1.5)
        .max_concurrent_panels(2)
        .async_panel_updates(true)
        .workspace_memory_limit(1024)
        .aggressive_memory_reuse(true)
        .preallocation_strategy(PreallocationStrategy::MatrixSize)
        .build();

    let o = &solver.options;
    assert_eq!(o.diagonal_pivot_threshold, 0.2);
    assert_eq!(o.column_permutation, ColumnPermutation::Metis);
    assert_eq!(o.row_permutation, RowPermutation::LargeDiag);
    assert_eq!(o.iterative_refinement, IterativeRefinement::Double);
    assert_eq!(o.print_level, 1);
    assert!(o.replace_tiny_pivots);
    assert!(!o.static_pivoting);
    assert_eq!(o.process_grid, Some((2, 2)));
    assert_eq!(o.panel_size, Some(32));
    assert!(!o.enable_3d_factorization);
    assert_eq!(o.memory_tradeoff_factor, 1.5);
    assert_eq!(o.max_concurrent_panels, 2);
    assert!(o.async_panel_updates);

    let w = &solver.workspace_config;
    assert_eq!(w.memory_limit_mb, 1024);
    assert!(w.aggressive_reuse);
    assert_eq!(w.preallocation_strategy, PreallocationStrategy::MatrixSize);
}
/// Every option set through the chained `set_*` methods is present on the solver's
/// options or workspace configuration.
///
/// Boolean options use `assert!` / `assert!(!..)` instead of comparing against
/// `true`/`false` (clippy::bool_assert_comparison).
#[test]
fn test_superlu_dist_fluent_configuration() {
    let mut solver = SuperLuDistSolver::new();
    solver
        .set_diagonal_pivot_threshold(0.3)
        .set_column_permutation(ColumnPermutation::ParMetis)
        .set_row_permutation(RowPermutation::NoRowPerm)
        .set_iterative_refinement(IterativeRefinement::Single)
        .set_print_level(2)
        .set_replace_tiny_pivots(false)
        .set_static_pivoting(true)
        .set_process_grid(4, 1)
        .set_panel_size(64)
        .set_3d_factorization(true, Some(2))
        .set_memory_tradeoff(2.0)
        .set_max_concurrent_panels(4)
        .set_async_panel_updates(false)
        .set_workspace_memory_limit(2048)
        .set_aggressive_memory_reuse(false)
        .set_preallocation_strategy(PreallocationStrategy::ProcessGrid);

    let o = &solver.options;
    assert_eq!(o.diagonal_pivot_threshold, 0.3);
    assert_eq!(o.column_permutation, ColumnPermutation::ParMetis);
    assert_eq!(o.row_permutation, RowPermutation::NoRowPerm);
    assert_eq!(o.iterative_refinement, IterativeRefinement::Single);
    assert_eq!(o.print_level, 2);
    assert!(!o.replace_tiny_pivots);
    assert!(o.static_pivoting);
    assert_eq!(o.process_grid, Some((4, 1)));
    assert_eq!(o.panel_size, Some(64));
    assert!(o.enable_3d_factorization);
    assert_eq!(o.process_grid_3d_depth, Some(2));
    assert_eq!(o.memory_tradeoff_factor, 2.0);
    assert_eq!(o.max_concurrent_panels, 4);
    assert!(!o.async_panel_updates);

    let w = &solver.workspace_config;
    assert_eq!(w.memory_limit_mb, 2048);
    assert!(!w.aggressive_reuse);
    assert_eq!(w.preallocation_strategy, PreallocationStrategy::ProcessGrid);
}
#[test]
fn test_superlu_dist_builder_with_refinement() {
    let refinement_config = RefinementConfig {
        max_iterations: 5,
        tolerance: 1e-10,
        relative_tolerance: 1e-8,
        min_improvement_factor: 0.95,
    };
    let solver = SuperLuDistSolver::builder()
        .diagonal_pivot_threshold(0.1)
        .iterative_refinement(IterativeRefinement::Double)
        .refinement_config(refinement_config)
        .residual_method(ResidualMethod::Scaled)
        .build();
    assert_eq!(
        solver.options.iterative_refinement,
        IterativeRefinement::Double
    );
    // Extract the engine unconditionally: a missing engine must fail the test
    // rather than silently skip the config checks below.
    let engine = solver
        .refinement_engine
        .as_ref()
        .expect("refinement_config() should install a refinement engine");
    let config = engine.config();
    assert_eq!(config.max_iterations, 5);
    assert_eq!(config.tolerance, 1e-10);
    assert_eq!(config.relative_tolerance, 1e-8);
    assert_eq!(config.min_improvement_factor, 0.95);
}
#[test]
fn test_superlu_dist_auto_process_grid() {
    let mut solver = SuperLuDistSolver::new();

    // An explicitly requested 2x3 grid is recorded as-is...
    solver.set_process_grid(2, 3);
    assert_eq!(solver.options.process_grid, Some((2, 3)));

    // ...and asking for automatic selection clears the explicit grid again.
    solver.set_process_grid_auto();
    assert!(solver.options.process_grid.is_none());
}
#[test]
fn test_superlu_dist_complete_options_replacement() {
    // A fully specified option set, with every field differing from the defaults
    // where practical.
    let new_options = SuperLuDistOptions {
        process_grid: Some((1, 4)),
        process_grid_3d_depth: Some(4),
        enable_3d_factorization: true,
        column_permutation: ColumnPermutation::Natural,
        row_permutation: RowPermutation::User,
        diagonal_pivot_threshold: 0.01,
        replace_tiny_pivots: false,
        static_pivoting: true,
        iterative_refinement: IterativeRefinement::Extra,
        print_level: 3,
        panel_size: Some(128),
        memory_tradeoff_factor: 3.0,
        max_concurrent_panels: 8,
        async_panel_updates: true,
    };
    let solver = SuperLuDistSolver::new().with_complete_options(new_options.clone());

    // `with_complete_options` must install every field of the supplied struct.
    let got = &solver.options;
    let want = &new_options;

    // Process-grid layout.
    assert_eq!(got.process_grid, want.process_grid);
    assert_eq!(got.enable_3d_factorization, want.enable_3d_factorization);
    assert_eq!(got.process_grid_3d_depth, want.process_grid_3d_depth);

    // Ordering and pivoting.
    assert_eq!(got.column_permutation, want.column_permutation);
    assert_eq!(got.row_permutation, want.row_permutation);
    assert_eq!(got.diagonal_pivot_threshold, want.diagonal_pivot_threshold);
    assert_eq!(got.replace_tiny_pivots, want.replace_tiny_pivots);
    assert_eq!(got.static_pivoting, want.static_pivoting);

    // Refinement and diagnostics.
    assert_eq!(got.iterative_refinement, want.iterative_refinement);
    assert_eq!(got.print_level, want.print_level);

    // Panel scheduling and memory tuning.
    assert_eq!(got.panel_size, want.panel_size);
    assert_eq!(got.memory_tradeoff_factor, want.memory_tradeoff_factor);
    assert_eq!(got.max_concurrent_panels, want.max_concurrent_panels);
    assert_eq!(got.async_panel_updates, want.async_panel_updates);
}
#[test]
fn test_superlu_dist_linear_solver_error_handling() {
    use crate::parallel::{NoComm, UniverseComm};
    // 2x2 identity matrix in CSR form: one unit diagonal entry per row.
    let matrix = CsrMatrix::from_csr(2, 2, vec![0, 1, 2], vec![0, 1], vec![1.0, 1.0]);
    let b: Vec<f64> = vec![1.0, 2.0];
    let mut x: Vec<f64> = vec![0.0, 0.0];
    let comm = UniverseComm::NoComm(NoComm);
    let mut solver = SuperLuDistSolver::new();
    // No factorization data exists before the first solve through the legacy API.
    assert!(solver.data.is_none());
    let result = crate::solver::legacy::LinearSolver::solve(
        &mut solver,
        &matrix,
        None,
        &b,
        &mut x,
        crate::preconditioner::PcSide::Left,
        &comm,
        None,
        None,
    );
    assert!(result.is_ok());
    // The solve must have populated the solver's data.
    assert!(solver.data.is_some());
    // For the identity matrix the solution must equal the right-hand side.
    for (xi, bi) in x.iter().zip(&b) {
        assert!((xi - bi).abs() < 1e-12, "x = {x:?}, expected {b:?}");
    }
}
#[test]
fn test_superlu_dist_builder_defaults() {
    // Building with no setters must yield the documented default configuration.
    let builder = SuperLuDistBuilder::new();
    let solver = builder.build();
    assert_eq!(solver.options.diagonal_pivot_threshold, 1.0);
    assert_eq!(solver.options.column_permutation, ColumnPermutation::MmdAta);
    assert_eq!(solver.options.row_permutation, RowPermutation::LargeDiag);
    assert_eq!(
        solver.options.iterative_refinement,
        IterativeRefinement::Double
    );
    assert_eq!(solver.options.print_level, 0);
    // Boolean flags are checked with `assert!` (clippy::bool_assert_comparison).
    assert!(!solver.options.replace_tiny_pivots);
    assert!(!solver.options.static_pivoting);
    assert_eq!(solver.options.process_grid, None);
    assert_eq!(solver.options.panel_size, None);
    assert!(!solver.options.enable_3d_factorization);
    assert_eq!(solver.options.process_grid_3d_depth, None);
    assert_eq!(solver.options.memory_tradeoff_factor, 1.0);
    assert_eq!(solver.options.max_concurrent_panels, 1);
    assert!(!solver.options.async_panel_updates);
    assert!(solver.refinement_engine.is_none());
}
#[test]
#[cfg(feature = "superlu3d")]
fn test_request_schedule_3d() {
    use crate::parallel::NoComm;

    let comm = UniverseComm::NoComm(NoComm);

    // Single-rank 2D grid and a 1x1 block-cyclic distribution over it.
    let grid2d = ProcessGrid {
        prows: 1,
        pcols: 1,
        my_prow: 0,
        my_pcol: 0,
        my_rank: 0,
        total_procs: 1,
    };
    let dist = BlockCyclicDistribution::new(grid2d.clone(), 0, 0, 1, 1);

    // 1x1x3 grid with this rank on depth layer 0.
    let grid3d = ProcessGrid3D {
        prows: 1,
        pcols: 1,
        pdepth: 3,
        my_prow: 0,
        my_pcol: 0,
        my_pdepth: 0,
        my_rank: 0,
        total_procs: 1,
    };

    // Empty solve state; only block 0's owner (rank 0) is registered.
    let mut solve_data = TriangularSolveData {
        local_solution_blocks: Vec::new(),
        comm_buffer: Vec::new(),
        pending_requests: Vec::new(),
        block_owners: vec![0],
        block_sizes: Vec::new(),
        local_l_factors: Vec::new(),
        local_u_factors: Vec::new(),
        dependency_graph: Vec::new(),
    };

    let block_id: usize = 5;
    DistributedTriangularSolver::start_nonblocking_broadcast(
        &mut solve_data,
        &[0.0],
        block_id,
        &dist,
        CommPattern::PointToPoint,
        &comm,
        Some(&grid3d),
    )
    .unwrap();

    // Tags are `(block_id << 8) + offset`. The expected offsets 1 and 2 presumably
    // correspond to the two non-local depth layers — NOTE(review): confirm.
    let mut tags = solve_data
        .pending_requests
        .iter()
        .map(|req| req.meta.tag)
        .collect::<Vec<usize>>();
    tags.sort_unstable();
    let base = block_id << 8;
    assert_eq!(tags, [base + 1, base + 2]);
}
}
#[cfg(test)]
mod send_sync_checks {
    use super::*;

    /// Compile-time probe: this only type-checks when `T: Send`.
    fn assert_send<T: Send>() {}

    /// The SuperLU_DIST solver and its supporting state must be `Send` so they
    /// can be moved to another thread; a regression fails to compile here.
    #[test]
    fn superlu_types_send() {
        assert_send::<SuperLuDistSolver>();
        assert_send::<SuperLuDistData>();
        assert_send::<SuperLuDistWorkspace>();
        assert_send::<MemoryPool>();
        assert_send::<CommBufferManager>();
    }
}