use crate::ordering::elimination_tree::EliminationTree;
use crate::symbolic::profiler::SymbolicProfiler;
use crate::symbolic::small_leaf::SmallLeafParams;
use std::sync::{Arc, Mutex};
/// Tunable parameters for supernode detection and amalgamation.
///
/// Of these fields, the code visible alongside this struct reads only `nemin`
/// (in `find_supernodes` and `predict_merges`) and `amalgamation_strategy`
/// (in `find_supernodes`); the others are carried along for other stages.
#[derive(Clone)]
pub struct SupernodeParams {
/// Size threshold for amalgamation: a child/parent pair qualifies for a
/// size-based merge only when both have fewer than `nemin` columns.
pub nemin: usize,
/// Ordering preprocessing mode.
/// NOTE(review): not read by the supernode routines visible here; presumably
/// consumed by the ordering stage — confirm against callers.
pub preprocess: OrderingPreprocess,
/// Small-leaf handling parameters.
/// NOTE(review): not read by the supernode routines visible here — confirm
/// where these are consumed.
pub small_leaf: SmallLeafParams,
/// Chooses the order in which `find_supernodes` walks supernodes while
/// amalgamating (see [`AmalgamationStrategy`]).
pub amalgamation_strategy: AmalgamationStrategy,
/// Optional shared handle to a symbolic-phase profiler; `None` by default.
/// NOTE(review): not used by the supernode routines visible here.
pub symbolic_profiler: Option<Arc<Mutex<SymbolicProfiler>>>,
}
/// Strategy knob for supernode amalgamation.
///
/// `find_supernodes` only distinguishes `Renumber` from everything else;
/// `pick_amalgamation_strategy` returns either `Adjacency` or `Renumber`.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum AmalgamationStrategy {
/// Amalgamation pass visits supernodes in ascending index order.
Adjacency,
/// Amalgamation pass visits supernodes in descending index order.
Renumber,
/// Default variant. `find_supernodes` handles it exactly like `Adjacency`.
/// NOTE(review): presumably callers resolve `Auto` through
/// `pick_amalgamation_strategy` before reaching `find_supernodes` — confirm.
#[default]
Auto,
}
/// Preprocessing mode applied before ordering.
///
/// NOTE(review): no code visible next to this enum inspects its variants, so
/// the descriptions below are inferred from the names only — confirm against
/// the ordering stage.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub enum OrderingPreprocess {
/// Presumably: no preprocessing.
None,
/// Presumably: an LDLT-oriented compression step — TODO confirm.
LdltCompress,
/// Default variant; presumably lets the library choose.
#[default]
Auto,
}
impl Default for SupernodeParams {
fn default() -> Self {
Self {
nemin: 16,
preprocess: OrderingPreprocess::Auto,
small_leaf: SmallLeafParams::default(),
amalgamation_strategy: AmalgamationStrategy::default(),
symbolic_profiler: None,
}
}
}
/// A supernode: a block of consecutive columns treated as one dense panel.
#[derive(Debug, Clone)]
pub struct Supernode {
    /// Index of the first column in the block.
    pub first_col: usize,
    /// Number of columns in the block.
    pub ncol: usize,
    /// Number of rows in the block (pivot rows plus contribution rows).
    pub nrow: usize,
    /// Row indices of the block; `find_supernodes` fills this with the
    /// contiguous range `first_col..first_col + nrow`.
    pub row_indices: Vec<usize>,
    /// Positions, within the same supernode list, of this node's children.
    pub children: Vec<usize>,
}

impl Supernode {
    /// Number of columns in this supernode.
    #[inline]
    pub fn ncol(&self) -> usize {
        self.ncol
    }

    /// Number of rows below the pivot block, i.e. `nrow - ncol`.
    #[inline]
    pub fn contrib_nrow(&self) -> usize {
        let (rows, cols) = (self.nrow, self.ncol);
        rows - cols
    }

    /// Entry count of the square contribution block (`contrib_nrow` squared).
    #[inline]
    pub fn contrib_size(&self) -> usize {
        let side = self.contrib_nrow();
        side * side
    }
}
/// Cut-off used by `pick_amalgamation_strategy`: when the fraction of
/// internal elimination-tree nodes that have two or more children is below
/// this value, `Adjacency` is chosen; otherwise `Renumber`.
pub const AUTO_MULTI_CHILD_FRAC_THRESHOLD: f64 = 0.05;
/// Chooses a concrete amalgamation strategy from the shape of `etree`.
///
/// Counts, among nodes that have at least one child, how many have two or
/// more. If that fraction is below [`AUTO_MULTI_CHILD_FRAC_THRESHOLD`] the
/// result is `Adjacency`, otherwise `Renumber`. An empty tree, or a tree in
/// which no node has a child, yields `Adjacency`.
pub fn pick_amalgamation_strategy(etree: &EliminationTree) -> AmalgamationStrategy {
    let node_count = etree.n;
    if node_count == 0 {
        return AmalgamationStrategy::Adjacency;
    }

    // Number of children per node.
    let mut kids = vec![0usize; node_count];
    for &par in etree.parent.iter().flatten() {
        kids[par] += 1;
    }

    // One sweep gathers both tallies: nodes with any child, and nodes with
    // at least two.
    let mut internal = 0usize;
    let mut branching = 0usize;
    for &k in &kids {
        if k >= 1 {
            internal += 1;
        }
        if k >= 2 {
            branching += 1;
        }
    }

    if internal == 0 {
        return AmalgamationStrategy::Adjacency;
    }

    let branching_share = branching as f64 / internal as f64;
    if branching_share >= AUTO_MULTI_CHILD_FRAC_THRESHOLD {
        AmalgamationStrategy::Renumber
    } else {
        AmalgamationStrategy::Adjacency
    }
}
pub fn find_supernodes(
etree: &EliminationTree,
col_counts: &[usize],
params: &SupernodeParams,
) -> Vec<Supernode> {
let n = etree.n;
if n == 0 {
return Vec::new();
}
let fund = find_fundamental_supernodes(etree, col_counts);
let snode_starts = fund.snode_starts;
let mut snode_ncols = fund.snode_ncols;
let snode_parent = fund.snode_parent;
let n_snodes = snode_starts.len();
let mut merged_into = vec![None::<usize>; n_snodes];
let mut snode_first_col: Vec<usize> = snode_starts.clone();
let reverse = matches!(params.amalgamation_strategy, AmalgamationStrategy::Renumber);
let order: Box<dyn Iterator<Item = usize>> = if reverse {
Box::new((0..n_snodes).rev())
} else {
Box::new(0..n_snodes)
};
for s in order {
let sp = snode_parent[s];
if let Some(p) = sp {
if find_root(s, &merged_into) != s {
continue; }
let root_s = find_root(s, &merged_into);
let root_p = find_root(p, &merged_into);
if root_s == root_p {
continue;
}
let s_first = snode_first_col[root_s];
let s_ncol = snode_ncols[root_s];
let p_first = snode_first_col[root_p];
if s_first + s_ncol != p_first {
continue;
}
let child_ncol = snode_ncols[root_s];
let parent_ncol = snode_ncols[root_p];
let trivial_chain = parent_ncol == 1 && {
let child_last = s_first + s_ncol - 1;
col_counts[p_first] + 1 == col_counts[child_last]
};
let size_based = child_ncol < params.nemin && parent_ncol < params.nemin;
if trivial_chain || size_based {
merged_into[root_s] = Some(root_p);
snode_ncols[root_p] = child_ncol + parent_ncol;
snode_first_col[root_p] = s_first;
}
}
}
let mut final_snodes: Vec<Supernode> = Vec::new();
let mut new_snode_id = vec![0usize; n_snodes];
for s in 0..n_snodes {
if merged_into[s].is_some() {
continue;
}
let first_col = snode_first_col[s];
let ncol = snode_ncols[s];
let nrow = col_counts[first_col].max(ncol);
let row_indices = (first_col..first_col + nrow).collect();
new_snode_id[s] = final_snodes.len();
final_snodes.push(Supernode {
first_col,
ncol,
nrow,
row_indices,
children: Vec::new(),
});
}
for s in 0..n_snodes {
if merged_into[s].is_some() {
continue;
}
if let Some(p) = snode_parent[s] {
let root_p = find_root(p, &merged_into);
if root_p != s {
let new_child = new_snode_id[s];
let new_parent = new_snode_id[root_p];
final_snodes[new_parent].children.push(new_child);
}
}
}
final_snodes
}
/// Follows `merged_into` links starting at `s` and returns the first node
/// whose entry is `None` (the representative of `s`).
///
/// Panics if a visited index is out of bounds for `merged_into`.
fn find_root(s: usize, merged_into: &[Option<usize>]) -> usize {
    let mut current = s;
    loop {
        match merged_into[current] {
            Some(next) => current = next,
            None => return current,
        }
    }
}
/// Result of `find_fundamental_supernodes`: three parallel vectors indexed
/// by fundamental-supernode id.
pub(crate) struct FundamentalSupernodes {
/// First column of each supernode.
pub(crate) snode_starts: Vec<usize>,
/// Number of columns in each supernode.
pub(crate) snode_ncols: Vec<usize>,
/// Id of the supernode containing the elimination-tree parent of each
/// supernode's last column, or `None` when that column has no parent.
pub(crate) snode_parent: Vec<Option<usize>>,
}
/// Partitions columns `0..etree.n` into fundamental supernodes.
///
/// Column `j` (for `j >= 1`) is appended to the supernode of column `j - 1`
/// when all of the following hold:
/// * `j` is the elimination-tree parent of `j - 1`,
/// * `col_counts[j] + 1 == col_counts[j - 1]`,
/// * `j` has exactly one child in the elimination tree.
///
/// Otherwise `j` opens a new supernode. The parent of a supernode is the
/// supernode holding the elimination-tree parent of its last column.
/// Returns three empty vectors when `etree.n == 0`.
pub(crate) fn find_fundamental_supernodes(
    etree: &EliminationTree,
    col_counts: &[usize],
) -> FundamentalSupernodes {
    let n = etree.n;
    if n == 0 {
        return FundamentalSupernodes {
            snode_starts: Vec::new(),
            snode_ncols: Vec::new(),
            snode_parent: Vec::new(),
        };
    }

    // Number of elimination-tree children of every column.
    let mut child_count = vec![0usize; n];
    for p in (0..n).filter_map(|j| etree.parent[j]) {
        child_count[p] += 1;
    }

    // `owner[j]` is the supernode id of column `j`; column 0 always opens
    // supernode 0. Widths are tallied while the partition is built.
    let mut owner = vec![0usize; n];
    let mut starts: Vec<usize> = vec![0];
    let mut widths: Vec<usize> = vec![1];
    for j in 1..n {
        let extends_previous = etree.parent[j - 1] == Some(j)
            && col_counts[j] + 1 == col_counts[j - 1]
            && child_count[j] == 1;
        if extends_previous {
            let id = owner[j - 1];
            owner[j] = id;
            widths[id] += 1;
        } else {
            owner[j] = starts.len();
            starts.push(j);
            widths.push(1);
        }
    }

    // The parent supernode is looked up through each supernode's last column.
    let parents: Vec<Option<usize>> = starts
        .iter()
        .zip(&widths)
        .map(|(&start, &width)| etree.parent[start + width - 1].map(|p| owner[p]))
        .collect();

    FundamentalSupernodes {
        snode_starts: starts,
        snode_ncols: widths,
        snode_parent: parents,
    }
}
/// Flags the columns of every fundamental supernode that satisfies a merge
/// criterion with respect to its parent supernode.
///
/// The returned vector has one entry per column (`etree.n` in total). A
/// column is `true` when its fundamental supernode has a parent and either
/// * the parent has a single column and
///   `col_counts[parent_first] + 1 == col_counts[child_last]`, or
/// * both supernodes have fewer than `params.nemin` columns.
///
/// The test is applied to the fundamental supernodes as they stand: no
/// column-adjacency check is made and earlier merges do not accumulate.
pub(crate) fn predict_merges(
    etree: &EliminationTree,
    col_counts: &[usize],
    params: &SupernodeParams,
) -> Vec<bool> {
    let mut bias = vec![false; etree.n];
    if etree.n == 0 {
        return bias;
    }

    let fund = find_fundamental_supernodes(etree, col_counts);
    for (s, link) in fund.snode_parent.iter().enumerate() {
        let p = match *link {
            Some(p) => p,
            None => continue,
        };

        let child_width = fund.snode_ncols[s];
        let parent_width = fund.snode_ncols[p];
        let child_first = fund.snode_starts[s];
        let parent_first = fund.snode_starts[p];
        let child_last = child_first + child_width - 1;

        let trivial_chain =
            parent_width == 1 && col_counts[parent_first] + 1 == col_counts[child_last];
        let size_based = child_width < params.nemin && parent_width < params.nemin;
        if !(trivial_chain || size_based) {
            continue;
        }

        bias.iter_mut()
            .skip(child_first)
            .take(child_width)
            .for_each(|flag| *flag = true);
    }
    bias
}
#[cfg(test)]
mod tests {
    //! Unit tests for `find_supernodes` on small hand-built patterns.
    //!
    //! Every test builds a matrix from lower-triangular triplets, derives
    //! the symmetric pattern, elimination tree and column counts, and then
    //! checks the supernode partition for a given `nemin`.

    use super::*;
    use crate::sparse::csc::CscMatrix;
    use crate::symbolic::column_counts::column_counts;

    /// 4x4 tridiagonal pattern with amalgamation disabled (`nemin = 1`):
    /// three supernodes covering all four columns are expected.
    #[test]
    fn test_supernodes_tridiagonal() {
        let m =
            CscMatrix::from_triplets(4, &[0, 1, 1, 2, 2, 3, 3], &[0, 0, 1, 1, 2, 2, 3], &[1.0; 7])
                .unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        let params = SupernodeParams {
            nemin: 1,
            ..Default::default()
        };
        let snodes = find_supernodes(&etree, &counts, &params);
        assert_eq!(snodes.len(), 3);
        let total_cols: usize = snodes.iter().map(|s| s.ncol()).sum();
        assert_eq!(total_cols, 4);
    }

    /// Same tridiagonal pattern with a generous `nemin`: everything is
    /// expected to collapse into a single supernode.
    #[test]
    fn test_supernodes_tridiagonal_amalgamated() {
        let m =
            CscMatrix::from_triplets(4, &[0, 1, 1, 2, 2, 3, 3], &[0, 0, 1, 1, 2, 2, 3], &[1.0; 7])
                .unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        let params = SupernodeParams {
            nemin: 32,
            ..Default::default()
        };
        let snodes = find_supernodes(&etree, &counts, &params);
        let total_cols: usize = snodes.iter().map(|s| s.ncol()).sum();
        assert_eq!(total_cols, 4);
        assert_eq!(snodes.len(), 1);
    }

    /// Dense 3x3 pattern: one supernode spanning all columns with an empty
    /// contribution block.
    #[test]
    fn test_supernodes_dense() {
        let m = CscMatrix::from_triplets(3, &[0, 1, 2, 1, 2, 2], &[0, 0, 0, 1, 1, 2], &[1.0; 6])
            .unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        let params = SupernodeParams {
            nemin: 1,
            ..Default::default()
        };
        let snodes = find_supernodes(&etree, &counts, &params);
        assert_eq!(snodes.len(), 1);
        assert_eq!(snodes[0].ncol(), 3);
        assert_eq!(snodes[0].nrow, 3);
        assert_eq!(snodes[0].contrib_size(), 0);
    }

    /// Two independent 2x2 blocks: two supernodes of two columns each.
    #[test]
    fn test_supernodes_block_diagonal() {
        let m = CscMatrix::from_triplets(4, &[0, 1, 1, 2, 3, 3], &[0, 0, 1, 2, 2, 3], &[1.0; 6])
            .unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        let params = SupernodeParams {
            nemin: 1,
            ..Default::default()
        };
        let snodes = find_supernodes(&etree, &counts, &params);
        assert_eq!(snodes.len(), 2);
        assert_eq!(snodes[0].ncol(), 2);
        assert_eq!(snodes[1].ncol(), 2);
    }

    /// Purely diagonal pattern with amalgamation disabled: every column
    /// stays its own supernode.
    #[test]
    fn test_supernodes_diagonal_no_amalg() {
        let m = CscMatrix::from_triplets(4, &[0, 1, 2, 3], &[0, 1, 2, 3], &[1.0; 4]).unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        let params = SupernodeParams {
            nemin: 1,
            ..Default::default()
        };
        let snodes = find_supernodes(&etree, &counts, &params);
        assert_eq!(snodes.len(), 4);
    }

    /// Arrow pattern (dense first column plus diagonal): whatever `nemin`
    /// is, the supernodes must account for all five columns exactly once.
    #[test]
    fn test_supernodes_total_columns() {
        let m = CscMatrix::from_triplets(
            5,
            &[0, 1, 2, 3, 4, 1, 2, 3, 4],
            &[0, 0, 0, 0, 0, 1, 2, 3, 4],
            &[1.0; 9],
        )
        .unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        for nemin in [1, 5, 32] {
            let params = SupernodeParams {
                nemin,
                ..Default::default()
            };
            let snodes = find_supernodes(&etree, &counts, &params);
            let total: usize = snodes.iter().map(|s| s.ncol()).sum();
            assert_eq!(total, 5, "nemin={}: total columns {} != 5", nemin, total);
        }
    }

    /// Child links must be valid positions in the result and must always
    /// precede their parent.
    #[test]
    fn test_supernode_children_valid() {
        let m = CscMatrix::from_triplets(
            5,
            &[0, 1, 2, 3, 4, 1, 2, 3, 4],
            &[0, 0, 0, 0, 0, 1, 2, 3, 4],
            &[1.0; 9],
        )
        .unwrap();
        let pat = m.symmetric_pattern();
        let etree = EliminationTree::from_pattern(&pat);
        let counts = column_counts(&pat, &etree);
        let params = SupernodeParams {
            nemin: 1,
            ..Default::default()
        };
        let snodes = find_supernodes(&etree, &counts, &params);
        for (i, s) in snodes.iter().enumerate() {
            for &child in &s.children {
                assert!(child < snodes.len(), "invalid child index");
                assert!(
                    child < i,
                    "child {} should come before parent {} in postorder",
                    child,
                    i
                );
            }
        }
    }
}